Skip to content

Commit 9fd4e8e

Browse files
committed
chore: refactor to remove merge mixin
1 parent 906be74 commit 9fd4e8e

File tree

8 files changed: +37 additions, -124 deletions (lines changed)

8 files changed

+37
-124
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1633,6 +1633,30 @@ def _insert_overwrite_by_condition(
16331633
target_columns_to_types=target_columns_to_types,
16341634
order_projections=False,
16351635
)
1636+
elif insert_overwrite_strategy.is_merge:
1637+
columns = [exp.column(col) for col in target_columns_to_types]
1638+
when_not_matched_by_source = exp.When(
1639+
matched=False,
1640+
source=True,
1641+
condition=where,
1642+
then=exp.Delete(),
1643+
)
1644+
when_not_matched_by_target = exp.When(
1645+
matched=False,
1646+
source=False,
1647+
then=exp.Insert(
1648+
this=exp.Tuple(expressions=columns),
1649+
expression=exp.Tuple(expressions=columns),
1650+
),
1651+
)
1652+
self._merge(
1653+
target_table=table_name,
1654+
query=query,
1655+
on=exp.false(),
1656+
whens=exp.Whens(
1657+
expressions=[when_not_matched_by_source, when_not_matched_by_target]
1658+
),
1659+
)
16361660
else:
16371661
insert_exp = exp.insert(
16381662
query,

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99

1010
from sqlmesh.core.dialect import to_schema
1111
from sqlmesh.core.engine_adapter.mixins import (
12-
InsertOverwriteWithMergeMixin,
1312
ClusteredByMixin,
1413
RowDiffMixin,
1514
TableAlterClusterByOperation,
@@ -20,6 +19,7 @@
2019
DataObjectType,
2120
SourceQuery,
2221
set_catalog,
22+
InsertOverwriteStrategy,
2323
)
2424
from sqlmesh.core.node import IntervalUnit
2525
from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport
@@ -54,7 +54,7 @@
5454

5555

5656
@set_catalog()
57-
class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, RowDiffMixin):
57+
class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin):
5858
"""
5959
BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API.
6060
"""
@@ -68,6 +68,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row
6868
MAX_COLUMN_COMMENT_LENGTH = 1024
6969
SUPPORTS_QUERY_EXECUTION_TRACKING = True
7070
SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
71+
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE
7172

7273
SCHEMA_DIFFER_KWARGS = {
7374
"compatible_types": {

sqlmesh/core/engine_adapter/fabric.py

Lines changed: 1 addition & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -7,22 +7,15 @@
77
from functools import cached_property
88
from sqlglot import exp
99
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result
10+
from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
1011
from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter
1112
from sqlmesh.core.engine_adapter.shared import (
1213
InsertOverwriteStrategy,
13-
SourceQuery,
1414
)
15-
from sqlmesh.core.engine_adapter.base import EngineAdapter
1615
from sqlmesh.utils.errors import SQLMeshError
1716
from sqlmesh.utils.connection_pool import ConnectionPool
1817

1918

20-
if t.TYPE_CHECKING:
21-
from sqlmesh.core._typing import TableName
22-
23-
24-
from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin
25-
2619
logger = logging.getLogger(__name__)
2720

2821

@@ -58,26 +51,6 @@ def _target_catalog(self) -> t.Optional[str]:
5851
def _target_catalog(self, value: t.Optional[str]) -> None:
5952
self._connection_pool.set_attribute("target_catalog", value)
6053

61-
def _insert_overwrite_by_condition(
62-
self,
63-
table_name: TableName,
64-
source_queries: t.List[SourceQuery],
65-
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
66-
where: t.Optional[exp.Condition] = None,
67-
insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
68-
**kwargs: t.Any,
69-
) -> None:
70-
# Override to avoid MERGE statement which isn't fully supported in Fabric
71-
return EngineAdapter._insert_overwrite_by_condition(
72-
self,
73-
table_name=table_name,
74-
source_queries=source_queries,
75-
target_columns_to_types=target_columns_to_types,
76-
where=where,
77-
insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
78-
**kwargs,
79-
)
80-
8154
@property
8255
def api_client(self) -> FabricHttpClient:
8356
# the requests Session is not guaranteed to be threadsafe

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 0 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
from sqlglot.helper import seq_get
1010

1111
from sqlmesh.core.engine_adapter.base import EngineAdapter
12-
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery
1312
from sqlmesh.core.node import IntervalUnit
1413
from sqlmesh.core.dialect import schema_
1514
from sqlmesh.core.schema_diff import TableAlterOperation
@@ -75,52 +74,6 @@ def _fetch_native_df(
7574
return df
7675

7776

78-
class InsertOverwriteWithMergeMixin(EngineAdapter):
79-
def _insert_overwrite_by_condition(
80-
self,
81-
table_name: TableName,
82-
source_queries: t.List[SourceQuery],
83-
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
84-
where: t.Optional[exp.Condition] = None,
85-
insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
86-
**kwargs: t.Any,
87-
) -> None:
88-
"""
89-
Some engines do not support `INSERT OVERWRITE` but instead support
90-
doing an "INSERT OVERWRITE" using a Merge expression but with the
91-
predicate being `False`.
92-
"""
93-
target_columns_to_types = target_columns_to_types or self.columns(table_name)
94-
for source_query in source_queries:
95-
with source_query as query:
96-
query = self._order_projections_and_filter(
97-
query, target_columns_to_types, where=where
98-
)
99-
columns = [exp.column(col) for col in target_columns_to_types]
100-
when_not_matched_by_source = exp.When(
101-
matched=False,
102-
source=True,
103-
condition=where,
104-
then=exp.Delete(),
105-
)
106-
when_not_matched_by_target = exp.When(
107-
matched=False,
108-
source=False,
109-
then=exp.Insert(
110-
this=exp.Tuple(expressions=columns),
111-
expression=exp.Tuple(expressions=columns),
112-
),
113-
)
114-
self._merge(
115-
target_table=table_name,
116-
query=query,
117-
on=exp.false(),
118-
whens=exp.Whens(
119-
expressions=[when_not_matched_by_source, when_not_matched_by_target]
120-
),
121-
)
122-
123-
12477
class HiveMetastoreTablePropertiesMixin(EngineAdapter):
12578
MAX_TABLE_COMMENT_LENGTH = 4000
12679
MAX_COLUMN_COMMENT_LENGTH = 4000

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
)
1717
from sqlmesh.core.engine_adapter.mixins import (
1818
GetCurrentCatalogFromFunctionMixin,
19-
InsertOverwriteWithMergeMixin,
2019
PandasNativeFetchDFSupportMixin,
2120
VarcharSizeWorkaroundMixin,
2221
RowDiffMixin,
@@ -41,7 +40,6 @@
4140
class MSSQLEngineAdapter(
4241
EngineAdapterWithIndexSupport,
4342
PandasNativeFetchDFSupportMixin,
44-
InsertOverwriteWithMergeMixin,
4543
GetCurrentCatalogFromFunctionMixin,
4644
VarcharSizeWorkaroundMixin,
4745
RowDiffMixin,
@@ -74,6 +72,7 @@ class MSSQLEngineAdapter(
7472
},
7573
}
7674
VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"}
75+
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE
7776

7877
@property
7978
def catalog_support(self) -> CatalogSupport:

sqlmesh/core/engine_adapter/shared.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -243,6 +243,8 @@ class InsertOverwriteStrategy(Enum):
243243
# Issue a single INSERT query to replace a data range. The assumption is that the query engine will transparently match partition bounds
244244
# and replace data rather than append to it. Trino is an example of this when `hive.insert-existing-partitions-behavior=OVERWRITE` is configured
245245
INTO_IS_OVERWRITE = 4
246+
# Do the INSERT OVERWRITE using merge since the engine doesn't support it natively
247+
MERGE = 5
246248

247249
@property
248250
def is_delete_insert(self) -> bool:
@@ -260,6 +262,10 @@ def is_replace_where(self) -> bool:
260262
def is_into_is_overwrite(self) -> bool:
261263
return self == InsertOverwriteStrategy.INTO_IS_OVERWRITE
262264

265+
@property
266+
def is_merge(self) -> bool:
267+
return self == InsertOverwriteStrategy.MERGE
268+
263269

264270
class SourceQuery:
265271
def __init__(

tests/core/engine_adapter/test_base.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,16 +13,13 @@
1313
from sqlmesh.core import dialect as d
1414
from sqlmesh.core.dialect import normalize_model_name
1515
from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport
16-
from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin
1716
from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObject
1817
from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation, NestedSupport
1918
from sqlmesh.utils import columns_to_types_to_struct
2019
from sqlmesh.utils.date import to_ds
2120
from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError
2221
from tests.core.engine_adapter import to_sql_calls
2322

24-
if t.TYPE_CHECKING:
25-
pass
2623

2724
pytestmark = pytest.mark.engine
2825

@@ -482,7 +479,8 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable):
482479
def test_insert_overwrite_by_condition_column_contains_unsafe_characters(
483480
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture
484481
):
485-
adapter = make_mocked_engine_adapter(InsertOverwriteWithMergeMixin)
482+
adapter = make_mocked_engine_adapter(EngineAdapter)
483+
adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE
486484

487485
source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types(
488486
parse_one("SELECT 1 AS c"), None, target_table="test_table"

tests/core/engine_adapter/test_mssql.py

Lines changed: 0 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
from sqlmesh.core.engine_adapter.shared import (
1717
DataObject,
1818
DataObjectType,
19-
InsertOverwriteStrategy,
2019
)
2120
from sqlmesh.utils.date import to_ds
2221
from tests.core.engine_adapter import to_sql_calls
@@ -342,46 +341,6 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_exi
342341
]
343342

344343

345-
def test_insert_overwrite_by_time_partition_replace_where_pandas(
346-
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
347-
):
348-
mocker.patch(
349-
"sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists",
350-
return_value=False,
351-
)
352-
353-
adapter = make_mocked_engine_adapter(MSSQLEngineAdapter)
354-
adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE
355-
356-
temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table")
357-
table_name = "test_table"
358-
temp_table_id = "abcdefgh"
359-
temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id)
360-
361-
df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]})
362-
adapter.insert_overwrite_by_time_partition(
363-
table_name,
364-
df,
365-
start="2022-01-01",
366-
end="2022-01-02",
367-
time_formatter=lambda x, _: exp.Literal.string(to_ds(x)),
368-
time_column="ds",
369-
target_columns_to_types={
370-
"a": exp.DataType.build("INT"),
371-
"ds": exp.DataType.build("STRING"),
372-
},
373-
)
374-
adapter._connection_pool.get().bulk_copy.assert_called_with(
375-
f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")]
376-
)
377-
378-
assert to_sql_calls(adapter) == [
379-
f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""",
380-
f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""",
381-
f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];",
382-
]
383-
384-
385344
def test_insert_append_pandas(
386345
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
387346
):

0 commit comments

Comments
 (0)