Skip to content

Commit d62e6bb

Browse files
authored
Merge branch 'main' into vchan/gcp-postgres-integration-tests
2 parents 3f0576c + 9afc728 commit d62e6bb

File tree

4 files changed

+48
-14
lines changed

4 files changed

+48
-14
lines changed

sqlmesh/core/engine_adapter/postgres.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import logging
4+
import re
45
import typing as t
5-
from functools import partial
6+
from functools import cached_property, partial
67
from sqlglot import exp
78

89
from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter
@@ -112,11 +113,8 @@ def merge(
112113
**kwargs: t.Any,
113114
) -> None:
114115
# Merge isn't supported until Postgres 15
115-
merge_impl = (
116-
super().merge
117-
if self._connection_pool.get().server_version >= 150000
118-
else partial(logical_merge, self)
119-
)
116+
major, minor = self.server_version
117+
merge_impl = super().merge if major >= 15 else partial(logical_merge, self)
120118
merge_impl( # type: ignore
121119
target_table,
122120
source_table,
@@ -125,3 +123,13 @@ def merge(
125123
when_matched=when_matched,
126124
merge_filter=merge_filter,
127125
)
126+
127+
@cached_property
128+
def server_version(self) -> t.Tuple[int, int]:
129+
"""Lazily fetch and cache major and minor server version"""
130+
if result := self.fetchone("SHOW server_version"):
131+
server_version, *_ = result
132+
match = re.search(r"(\d+)\.(\d+)", server_version)
133+
if match:
134+
return int(match.group(1)), int(match.group(2))
135+
return 0, 0

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,25 +372,31 @@ def update_auto_restatements(
372372
Args:
373373
next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp.
374374
"""
375+
next_auto_restatement_ts_deleted = []
376+
next_auto_restatement_ts_filtered = {}
377+
for k, v in next_auto_restatement_ts.items():
378+
if v is None:
379+
next_auto_restatement_ts_deleted.append(k)
380+
else:
381+
next_auto_restatement_ts_filtered[k] = v
382+
375383
for where in snapshot_name_version_filter(
376384
self.engine_adapter,
377-
next_auto_restatement_ts,
385+
next_auto_restatement_ts_deleted,
378386
column_prefix="snapshot",
379387
alias=None,
380388
batch_size=self.SNAPSHOT_BATCH_SIZE,
381389
):
382390
self.engine_adapter.delete_from(self.auto_restatements_table, where=where)
383391

384-
next_auto_restatement_ts_filtered = {
385-
k: v for k, v in next_auto_restatement_ts.items() if v is not None
386-
}
387392
if not next_auto_restatement_ts_filtered:
388393
return
389394

390-
self.engine_adapter.insert_append(
395+
self.engine_adapter.merge(
391396
self.auto_restatements_table,
392397
_auto_restatements_to_df(next_auto_restatement_ts_filtered),
393398
columns_to_types=self._auto_restatement_columns_to_types,
399+
unique_key=(exp.column("snapshot_name"), exp.column("snapshot_version")),
394400
)
395401

396402
def count(self) -> int:

tests/core/engine_adapter/integration/test_integration_postgres.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pytest
33
from pytest import FixtureRequest
44
from sqlmesh.core.engine_adapter import PostgresEngineAdapter
5-
from tests.core.engine_adapter.integration import TestContext
65

76
from tests.core.engine_adapter.integration import (
87
TestContext,
@@ -29,3 +28,8 @@ def engine_adapter(ctx: TestContext) -> PostgresEngineAdapter:
2928
def test_engine_adapter(ctx: TestContext):
3029
assert isinstance(ctx.engine_adapter, PostgresEngineAdapter)
3130
assert ctx.engine_adapter.fetchone("select 1") == (1,)
31+
32+
33+
def test_server_version_psycopg(ctx: TestContext):
34+
assert isinstance(ctx.engine_adapter, PostgresEngineAdapter)
35+
assert ctx.engine_adapter.server_version != (0, 0)

tests/core/engine_adapter/test_postgres.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_create_table_like(make_mocked_engine_adapter: t.Callable):
9494

9595
def test_merge_version_gte_15(make_mocked_engine_adapter: t.Callable):
9696
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
97-
adapter._connection_pool.get().server_version = 150000
97+
adapter.server_version = (15, 0)
9898

9999
adapter.merge(
100100
target_table="target",
@@ -117,7 +117,7 @@ def test_merge_version_lt_15(
117117
make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture
118118
):
119119
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
120-
adapter._connection_pool.get().server_version = 140000
120+
adapter.server_version = (14, 0)
121121

122122
temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
123123
table_name = "test"
@@ -161,3 +161,19 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
161161
assert to_sql_calls(adapter) == [
162162
'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE',
163163
]
164+
165+
166+
def test_server_version(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture):
167+
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
168+
169+
fetchone_mock = mocker.patch.object(adapter, "fetchone")
170+
fetchone_mock.return_value = ("14.0",)
171+
assert adapter.server_version == (14, 0)
172+
173+
del adapter.server_version
174+
fetchone_mock.return_value = ("15.8",)
175+
assert adapter.server_version == (15, 8)
176+
177+
del adapter.server_version
178+
fetchone_mock.return_value = ("15.13 (Debian 15.13-1.pgdg120+1)",)
179+
assert adapter.server_version == (15, 13)

0 commit comments

Comments
 (0)