diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index 9962c037ac..a736f5553b 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -1,8 +1,9 @@ from __future__ import annotations import logging +import re import typing as t -from functools import partial +from functools import cached_property, partial from sqlglot import exp from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter @@ -112,11 +113,8 @@ def merge( **kwargs: t.Any, ) -> None: # Merge isn't supported until Postgres 15 - merge_impl = ( - super().merge - if self._connection_pool.get().server_version >= 150000 - else partial(logical_merge, self) - ) + major, minor = self.server_version + merge_impl = super().merge if major >= 15 else partial(logical_merge, self) merge_impl( # type: ignore target_table, source_table, @@ -125,3 +123,13 @@ def merge( when_matched=when_matched, merge_filter=merge_filter, ) + + @cached_property + def server_version(self) -> t.Tuple[int, int]: + """Lazily fetch and cache major and minor server version""" + if result := self.fetchone("SHOW server_version"): + server_version, *_ = result + match = re.search(r"(\d+)\.(\d+)", server_version) + if match: + return int(match.group(1)), int(match.group(2)) + return 0, 0 diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 6064993087..30e0de00f2 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -372,25 +372,31 @@ def update_auto_restatements( Args: next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp. """ + next_auto_restatement_ts_deleted = [] + next_auto_restatement_ts_filtered = {} + for k, v in next_auto_restatement_ts.items(): + if v is None: + next_auto_restatement_ts_deleted.append(k) + else: + next_auto_restatement_ts_filtered[k] = v + for where in snapshot_name_version_filter( self.engine_adapter, - next_auto_restatement_ts, + next_auto_restatement_ts_deleted, column_prefix="snapshot", alias=None, batch_size=self.SNAPSHOT_BATCH_SIZE, ): self.engine_adapter.delete_from(self.auto_restatements_table, where=where) - next_auto_restatement_ts_filtered = { - k: v for k, v in next_auto_restatement_ts.items() if v is not None - } if not next_auto_restatement_ts_filtered: return - self.engine_adapter.insert_append( + self.engine_adapter.merge( self.auto_restatements_table, _auto_restatements_to_df(next_auto_restatement_ts_filtered), columns_to_types=self._auto_restatement_columns_to_types, + unique_key=(exp.column("snapshot_name"), exp.column("snapshot_version")), ) def count(self) -> int: diff --git a/tests/core/engine_adapter/integration/test_integration_postgres.py b/tests/core/engine_adapter/integration/test_integration_postgres.py index 863aae55a4..82172378ae 100644 --- a/tests/core/engine_adapter/integration/test_integration_postgres.py +++ b/tests/core/engine_adapter/integration/test_integration_postgres.py @@ -2,7 +2,6 @@ import pytest from pytest import FixtureRequest from sqlmesh.core.engine_adapter import PostgresEngineAdapter -from tests.core.engine_adapter.integration import TestContext from tests.core.engine_adapter.integration import ( TestContext, @@ -29,3 +28,8 @@ def engine_adapter(ctx: TestContext) -> PostgresEngineAdapter: def test_engine_adapter(ctx: TestContext): assert isinstance(ctx.engine_adapter, PostgresEngineAdapter) assert ctx.engine_adapter.fetchone("select 1") == (1,) + + +def test_server_version_psycopg(ctx: TestContext): + assert isinstance(ctx.engine_adapter, PostgresEngineAdapter) + assert ctx.engine_adapter.server_version != (0, 0) diff --git a/tests/core/engine_adapter/test_postgres.py b/tests/core/engine_adapter/test_postgres.py index f013914c3e..fd6ce44994 100644 --- a/tests/core/engine_adapter/test_postgres.py +++ b/tests/core/engine_adapter/test_postgres.py @@ -94,7 +94,7 @@ def test_create_table_like(make_mocked_engine_adapter: t.Callable): def test_merge_version_gte_15(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(PostgresEngineAdapter) - adapter._connection_pool.get().server_version = 150000 + adapter.server_version = (15, 0) adapter.merge( target_table="target", @@ -117,7 +117,7 @@ def test_merge_version_lt_15( make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture ): adapter = make_mocked_engine_adapter(PostgresEngineAdapter) - adapter._connection_pool.get().server_version = 140000 + adapter.server_version = (14, 0) temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") table_name = "test" @@ -161,3 +161,19 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: assert to_sql_calls(adapter) == [ 'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE', ] + + +def test_server_version(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + + fetchone_mock = mocker.patch.object(adapter, "fetchone") + fetchone_mock.return_value = ("14.0",) + assert adapter.server_version == (14, 0) + + del adapter.server_version + fetchone_mock.return_value = ("15.8",) + assert adapter.server_version == (15, 8) + + del adapter.server_version + fetchone_mock.return_value = ("15.13 (Debian 15.13-1.pgdg120+1)",) + assert adapter.server_version == (15, 13)