Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions sqlmesh/core/engine_adapter/postgres.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import logging
import re
import typing as t
from functools import partial
from functools import cached_property, partial
from sqlglot import exp

from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter
Expand Down Expand Up @@ -112,11 +113,8 @@ def merge(
**kwargs: t.Any,
) -> None:
# Merge isn't supported until Postgres 15
merge_impl = (
super().merge
if self._connection_pool.get().server_version >= 150000
else partial(logical_merge, self)
)
major, minor = self.server_version
merge_impl = super().merge if major >= 15 else partial(logical_merge, self)
merge_impl( # type: ignore
target_table,
source_table,
Expand All @@ -125,3 +123,13 @@ def merge(
when_matched=when_matched,
merge_filter=merge_filter,
)

@cached_property
def server_version(self) -> t.Tuple[int, int]:
"""Lazily fetch and cache major and minor server version"""
if result := self.fetchone("SHOW server_version"):
server_version, *_ = result
match = re.search(r"(\d+)\.(\d+)", server_version)
if match:
return int(match.group(1)), int(match.group(2))
return 0, 0
16 changes: 11 additions & 5 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,25 +372,31 @@ def update_auto_restatements(
Args:
next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp.
"""
next_auto_restatement_ts_deleted = []
next_auto_restatement_ts_filtered = {}
for k, v in next_auto_restatement_ts.items():
if v is None:
next_auto_restatement_ts_deleted.append(k)
else:
next_auto_restatement_ts_filtered[k] = v

for where in snapshot_name_version_filter(
self.engine_adapter,
next_auto_restatement_ts,
next_auto_restatement_ts_deleted,
column_prefix="snapshot",
alias=None,
batch_size=self.SNAPSHOT_BATCH_SIZE,
):
self.engine_adapter.delete_from(self.auto_restatements_table, where=where)

next_auto_restatement_ts_filtered = {
k: v for k, v in next_auto_restatement_ts.items() if v is not None
}
if not next_auto_restatement_ts_filtered:
return

self.engine_adapter.insert_append(
self.engine_adapter.merge(
self.auto_restatements_table,
_auto_restatements_to_df(next_auto_restatement_ts_filtered),
columns_to_types=self._auto_restatement_columns_to_types,
unique_key=(exp.column("snapshot_name"), exp.column("snapshot_version")),
)

def count(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pytest
from pytest import FixtureRequest
from sqlmesh.core.engine_adapter import PostgresEngineAdapter
from tests.core.engine_adapter.integration import TestContext

from tests.core.engine_adapter.integration import (
TestContext,
Expand All @@ -29,3 +28,8 @@ def engine_adapter(ctx: TestContext) -> PostgresEngineAdapter:
def test_engine_adapter(ctx: TestContext):
assert isinstance(ctx.engine_adapter, PostgresEngineAdapter)
assert ctx.engine_adapter.fetchone("select 1") == (1,)


def test_server_version_psycopg(ctx: TestContext):
assert isinstance(ctx.engine_adapter, PostgresEngineAdapter)
assert ctx.engine_adapter.server_version != (0, 0)
20 changes: 18 additions & 2 deletions tests/core/engine_adapter/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_create_table_like(make_mocked_engine_adapter: t.Callable):

def test_merge_version_gte_15(make_mocked_engine_adapter: t.Callable):
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
adapter._connection_pool.get().server_version = 150000
adapter.server_version = (15, 0)

adapter.merge(
target_table="target",
Expand All @@ -117,7 +117,7 @@ def test_merge_version_lt_15(
make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture
):
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
adapter._connection_pool.get().server_version = 140000
adapter.server_version = (14, 0)

temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
table_name = "test"
Expand Down Expand Up @@ -161,3 +161,19 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
assert to_sql_calls(adapter) == [
'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE',
]


def test_server_version(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)

fetchone_mock = mocker.patch.object(adapter, "fetchone")
fetchone_mock.return_value = ("14.0",)
assert adapter.server_version == (14, 0)

del adapter.server_version
fetchone_mock.return_value = ("15.8",)
assert adapter.server_version == (15, 8)

del adapter.server_version
fetchone_mock.return_value = ("15.13 (Debian 15.13-1.pgdg120+1)",)
assert adapter.server_version == (15, 13)
Loading