Skip to content

Commit d264a76

Browse files
committed
Feat: Re-introduce merge for updating auto restatements
1 parent 29d7db7 commit d264a76

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

sqlmesh/core/engine_adapter/postgres.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import re
45
import typing as t
56
from functools import partial
67
from sqlglot import exp
@@ -112,11 +113,8 @@ def merge(
112113
**kwargs: t.Any,
113114
) -> None:
114115
# Merge isn't supported until Postgres 15
115-
merge_impl = (
116-
super().merge
117-
if self._connection_pool.get().server_version >= 150000
118-
else partial(logical_merge, self)
119-
)
116+
major, minor = self.get_server_version()
117+
merge_impl = super().merge if major >= 15 else partial(logical_merge, self)
120118
merge_impl( # type: ignore
121119
target_table,
122120
source_table,
@@ -125,3 +123,23 @@ def merge(
125123
when_matched=when_matched,
126124
merge_filter=merge_filter,
127125
)
126+
127+
def get_server_version(self) -> t.Tuple[int, int]:
128+
"""Return major and minor server versions of the connection"""
129+
connection = self._connection_pool.get()
130+
connection_module = connection.__class__.__module__
131+
if connection_module.startswith("pg8000"):
132+
server_version = connection.parameter_statuses.get("server_version")
133+
# pg8000 server version contains version as well as packaging and distribution information
134+
# e.g. 15.13 (Debian 15.13-1.pgdg120+1)
135+
match = re.search(r"(\d+)\.(\d+)", server_version)
136+
if match:
137+
return int(match.group(1)), int(match.group(2))
138+
elif connection_module.startswith("psycopg"):
139+
# This handles both psycopg and psycopg2 connection objects
140+
server_version = connection.info.server_version
141+
# Since major version 10, PostgreSQL represents the server version with an integer by
142+
# multiplying the server's major version number by 10000 and adding the minor version number
143+
# See https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
144+
return server_version // 10000, server_version % 100
145+
return 0, 0

tests/core/engine_adapter/test_postgres.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def test_create_table_like(make_mocked_engine_adapter: t.Callable):
9494

9595
def test_merge_version_gte_15(make_mocked_engine_adapter: t.Callable):
9696
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
97-
adapter._connection_pool.get().server_version = 150000
97+
adapter._connection_pool.get().__class__.__module__ = "psycopg2.extensions"
98+
adapter._connection_pool.get().info.server_version = 150000
9899

99100
adapter.merge(
100101
target_table="target",
@@ -117,7 +118,8 @@ def test_merge_version_lt_15(
117118
make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture
118119
):
119120
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
120-
adapter._connection_pool.get().server_version = 140000
121+
adapter._connection_pool.get().__class__.__module__ = "psycopg2.extensions"
122+
adapter._connection_pool.get().info.server_version = 140000
121123

122124
temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
123125
table_name = "test"
@@ -161,3 +163,17 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]:
161163
assert to_sql_calls(adapter) == [
162164
'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE',
163165
]
166+
167+
168+
def test_get_server_version(make_mocked_engine_adapter: t.Callable):
169+
adapter = make_mocked_engine_adapter(PostgresEngineAdapter)
170+
171+
adapter._connection_pool.get().__class__.__module__ = "psycopg2.extensions"
172+
adapter._connection_pool.get().info.server_version = 150013
173+
assert adapter.get_server_version() == (15, 13)
174+
175+
adapter._connection_pool.get().__class__.__module__ = "pg8000.native"
176+
adapter._connection_pool.get().parameter_statuses = {
177+
"server_version": "15.13 (Debian 15.13-1.pgdg120+1)"
178+
}
179+
assert adapter.get_server_version() == (15, 13)

0 commit comments

Comments
 (0)