From 329b2f2dc383f8970bc965daa65ea879e33626cc Mon Sep 17 00:00:00 2001 From: eakmanrq <6326532+eakmanrq@users.noreply.github.com> Date: Mon, 15 Sep 2025 12:13:34 -0700 Subject: [PATCH] feat: use merge for delta and iceberg trino catalogs --- sqlmesh/core/engine_adapter/base.py | 24 +++++++ sqlmesh/core/engine_adapter/bigquery.py | 5 +- sqlmesh/core/engine_adapter/fabric.py | 30 +-------- sqlmesh/core/engine_adapter/mixins.py | 47 ------------- sqlmesh/core/engine_adapter/mssql.py | 3 +- sqlmesh/core/engine_adapter/shared.py | 6 ++ sqlmesh/core/engine_adapter/trino.py | 39 ++++++++--- tests/core/engine_adapter/test_base.py | 6 +- tests/core/engine_adapter/test_mssql.py | 41 ------------ tests/core/engine_adapter/test_trino.py | 89 +++++++++++++++++++++++++ 10 files changed, 155 insertions(+), 135 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index c48ce2154d..94900f0193 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1633,6 +1633,30 @@ def _insert_overwrite_by_condition( target_columns_to_types=target_columns_to_types, order_projections=False, ) + elif insert_overwrite_strategy.is_merge: + columns = [exp.column(col) for col in target_columns_to_types] + when_not_matched_by_source = exp.When( + matched=False, + source=True, + condition=where, + then=exp.Delete(), + ) + when_not_matched_by_target = exp.When( + matched=False, + source=False, + then=exp.Insert( + this=exp.Tuple(expressions=columns), + expression=exp.Tuple(expressions=columns), + ), + ) + self._merge( + target_table=table_name, + query=query, + on=exp.false(), + whens=exp.Whens( + expressions=[when_not_matched_by_source, when_not_matched_by_target] + ), + ) else: insert_exp = exp.insert( query, diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index b3d02d8bbf..00b33f67a5 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -9,7 +9,6 @@ from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.mixins import ( - InsertOverwriteWithMergeMixin, ClusteredByMixin, RowDiffMixin, TableAlterClusterByOperation, @@ -20,6 +19,7 @@ DataObjectType, SourceQuery, set_catalog, + InsertOverwriteStrategy, ) from sqlmesh.core.node import IntervalUnit from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport @@ -54,7 +54,7 @@ @set_catalog() -class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, RowDiffMixin): +class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin): """ BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API. """ @@ -68,6 +68,7 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin, ClusteredByMixin, Row MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_QUERY_EXECUTION_TRACKING = True SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE SCHEMA_DIFFER_KWARGS = { "compatible_types": { diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py index a528be3cb4..773e41a4b3 100644 --- a/sqlmesh/core/engine_adapter/fabric.py +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -7,22 +7,14 @@ from functools import cached_property from sqlglot import exp from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result +from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter from sqlmesh.core.engine_adapter.shared import ( InsertOverwriteStrategy, - SourceQuery, ) -from sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.connection_pool import ConnectionPool - -if t.TYPE_CHECKING: - from sqlmesh.core._typing import TableName - - -from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin - logger = logging.getLogger(__name__) @@ -58,26 +50,6 @@ def _target_catalog(self) -> t.Optional[str]: def _target_catalog(self, value: t.Optional[str]) -> None: self._connection_pool.set_attribute("target_catalog", value) - def _insert_overwrite_by_condition( - self, - table_name: TableName, - source_queries: t.List[SourceQuery], - target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - where: t.Optional[exp.Condition] = None, - insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, - **kwargs: t.Any, - ) -> None: - # Override to avoid MERGE statement which isn't fully supported in Fabric - return EngineAdapter._insert_overwrite_by_condition( - self, - table_name=table_name, - source_queries=source_queries, - target_columns_to_types=target_columns_to_types, - where=where, - insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, - **kwargs, - ) - @property def api_client(self) -> FabricHttpClient: # the requests Session is not guaranteed to be threadsafe diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 865e47fb93..1d66da0607 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -9,7 +9,6 @@ from sqlglot.helper import seq_get from sqlmesh.core.engine_adapter.base import EngineAdapter -from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery from sqlmesh.core.node import IntervalUnit from sqlmesh.core.dialect import schema_ from sqlmesh.core.schema_diff import TableAlterOperation @@ -75,52 +74,6 @@ def _fetch_native_df( return df -class InsertOverwriteWithMergeMixin(EngineAdapter): - def _insert_overwrite_by_condition( - self, - table_name: TableName, - source_queries: t.List[SourceQuery], - target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - where: t.Optional[exp.Condition] = None, - insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, - **kwargs: t.Any, - ) -> None: - """ - Some engines do not support `INSERT OVERWRITE` but instead support - doing an "INSERT OVERWRITE" using a Merge expression but with the - predicate being `False`. - """ - target_columns_to_types = target_columns_to_types or self.columns(table_name) - for source_query in source_queries: - with source_query as query: - query = self._order_projections_and_filter( - query, target_columns_to_types, where=where - ) - columns = [exp.column(col) for col in target_columns_to_types] - when_not_matched_by_source = exp.When( - matched=False, - source=True, - condition=where, - then=exp.Delete(), - ) - when_not_matched_by_target = exp.When( - matched=False, - source=False, - then=exp.Insert( - this=exp.Tuple(expressions=columns), - expression=exp.Tuple(expressions=columns), - ), - ) - self._merge( - target_table=table_name, - query=query, - on=exp.false(), - whens=exp.Whens( - expressions=[when_not_matched_by_source, when_not_matched_by_target] - ), - ) - - class HiveMetastoreTablePropertiesMixin(EngineAdapter): MAX_TABLE_COMMENT_LENGTH = 4000 MAX_COLUMN_COMMENT_LENGTH = 4000 diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index 50a67b4b37..fd0bf1011b 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -16,7 +16,6 @@ ) from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, - InsertOverwriteWithMergeMixin, PandasNativeFetchDFSupportMixin, VarcharSizeWorkaroundMixin, RowDiffMixin, @@ -41,7 +40,6 @@ class MSSQLEngineAdapter( EngineAdapterWithIndexSupport, PandasNativeFetchDFSupportMixin, - InsertOverwriteWithMergeMixin, GetCurrentCatalogFromFunctionMixin, VarcharSizeWorkaroundMixin, RowDiffMixin, @@ -74,6 +72,7 @@ class MSSQLEngineAdapter( }, } VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"} + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE @property def catalog_support(self) -> CatalogSupport: diff --git a/sqlmesh/core/engine_adapter/shared.py b/sqlmesh/core/engine_adapter/shared.py index 55f04a995e..ba0e1fa619 100644 --- a/sqlmesh/core/engine_adapter/shared.py +++ b/sqlmesh/core/engine_adapter/shared.py @@ -243,6 +243,8 @@ class InsertOverwriteStrategy(Enum): # Issue a single INSERT query to replace a data range. The assumption is that the query engine will transparently match partition bounds # and replace data rather than append to it. Trino is an example of this when `hive.insert-existing-partitions-behavior=OVERWRITE` is configured INTO_IS_OVERWRITE = 4 + # Do the INSERT OVERWRITE using merge since the engine doesn't support it natively + MERGE = 5 @property def is_delete_insert(self) -> bool: @@ -260,6 +262,10 @@ def is_replace_where(self) -> bool: def is_into_is_overwrite(self) -> bool: return self == InsertOverwriteStrategy.INTO_IS_OVERWRITE + @property + def is_merge(self) -> bool: + return self == InsertOverwriteStrategy.MERGE + class SourceQuery: def __init__( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 0e6853dd4a..67f2efd340 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -34,8 +34,6 @@ from sqlmesh.core._typing import SchemaName, SessionProperties, TableName from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF -CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {"iceberg", "delta_lake"} - @set_catalog() class TrinoEngineAdapter( @@ -117,6 +115,22 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]: finally: self.execute(f"RESET SESSION AUTHORIZATION") + @classmethod + def _is_hive_catalog(cls, catalog_type: str) -> bool: + return catalog_type == "hive" + + @classmethod + def _is_delta_lake_catalog(cls, catalog_type: str) -> bool: + # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name + # Ex: `acme_delta_lake` would be identified as an delta lake catalog + return "delta_lake" in catalog_type + + @classmethod + def _is_iceberg_catalog(cls, catalog_type: str) -> bool: + # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name + # Ex: `acme_iceberg` would be identified as an iceberg catalog + return "iceberg" in catalog_type + def replace_query( self, table_name: TableName, @@ -129,13 +143,9 @@ def replace_query( **kwargs: t.Any, ) -> None: catalog_type = self.get_catalog_type_from_table(table_name) - # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name - # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table - supports_replace_table_override = None - for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE: - if replace_table_catalog_type in catalog_type: - supports_replace_table_override = True - break + supports_replace_table_override = self._is_delta_lake_catalog( + catalog_type + ) or self._is_iceberg_catalog(catalog_type) super().replace_query( table_name=table_name, @@ -158,8 +168,9 @@ def _insert_overwrite_by_condition( **kwargs: t.Any, ) -> None: catalog = exp.to_table(table_name).catalog or self.get_current_catalog() + catalog_type = self.get_catalog_type(catalog) - if where and self.get_catalog_type(catalog) == "hive": + if where and self._is_hive_catalog(catalog_type): # These session properties are only valid for the Trino Hive connector # Attempting to set them on an Iceberg catalog will throw an error: # "Session property 'catalog.insert_existing_partitions_behavior' does not exist" @@ -168,6 +179,14 @@ def _insert_overwrite_by_condition( table_name, source_queries, target_columns_to_types, where ) self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'") + elif self._is_delta_lake_catalog(catalog_type) or self._is_iceberg_catalog(catalog_type): + super()._insert_overwrite_by_condition( + table_name, + source_queries, + target_columns_to_types, + where, + insert_overwrite_strategy_override=InsertOverwriteStrategy.MERGE, + ) else: super()._insert_overwrite_by_condition( table_name, diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index b2dfcc7ccc..220c3291f7 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -13,7 +13,6 @@ from sqlmesh.core import dialect as d from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport -from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObject from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation, NestedSupport from sqlmesh.utils import columns_to_types_to_struct @@ -21,8 +20,6 @@ from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError from tests.core.engine_adapter import to_sql_calls -if t.TYPE_CHECKING: - pass pytestmark = pytest.mark.engine @@ -482,7 +479,8 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): def test_insert_overwrite_by_condition_column_contains_unsafe_characters( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): - adapter = make_mocked_engine_adapter(InsertOverwriteWithMergeMixin) + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( parse_one("SELECT 1 AS c"), None, target_table="test_table" diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py index 5923afa217..a405bb7576 100644 --- a/tests/core/engine_adapter/test_mssql.py +++ b/tests/core/engine_adapter/test_mssql.py @@ -16,7 +16,6 @@ from sqlmesh.core.engine_adapter.shared import ( DataObject, DataObjectType, - InsertOverwriteStrategy, ) from sqlmesh.utils.date import to_ds from tests.core.engine_adapter import to_sql_calls @@ -342,46 +341,6 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_exi ] -def test_insert_overwrite_by_time_partition_replace_where_pandas( - make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable -): - mocker.patch( - "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists", - return_value=False, - ) - - adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) - adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE - - temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") - table_name = "test_table" - temp_table_id = "abcdefgh" - temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) - - df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]}) - adapter.insert_overwrite_by_time_partition( - table_name, - df, - start="2022-01-01", - end="2022-01-02", - time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - time_column="ds", - target_columns_to_types={ - "a": exp.DataType.build("INT"), - "ds": exp.DataType.build("STRING"), - }, - ) - adapter._connection_pool.get().bulk_copy.assert_called_with( - f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")] - ) - - assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""", - f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""", - f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];", - ] - - def test_insert_append_pandas( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable ): diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py index 07c4657eb3..c3a54e39d6 100644 --- a/tests/core/engine_adapter/test_trino.py +++ b/tests/core/engine_adapter/test_trino.py @@ -11,6 +11,7 @@ from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import SqlModel from sqlmesh.core.dialect import schema_ +from sqlmesh.utils.date import to_ds from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls @@ -683,3 +684,91 @@ def test_replace_table_catalog_support( sql_calls[0] == f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"' ) + + +def test_insert_overwrite_time_partition_hive( + trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture +): + adapter = trino_mocked_engine_adapter + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["hive", "schema", "test_table"]), + query_or_df=parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + "SET SESSION hive.insert_existing_partitions_behavior='OVERWRITE'", + 'INSERT INTO "hive"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + "SET SESSION hive.insert_existing_partitions_behavior='APPEND'", + ] + + +def test_insert_overwrite_time_partition_iceberg( + trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture +): + adapter = trino_mocked_engine_adapter + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["acme_iceberg", "schema", "test_table"]), + query_or_df=parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + 'MERGE INTO "acme_iceberg"."schema"."test_table" AS "__merge_target__" USING (SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\') AS "__MERGE_SOURCE__" ON FALSE WHEN NOT MATCHED BY SOURCE AND "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\' THEN DELETE WHEN NOT MATCHED THEN INSERT ("a", "b") VALUES ("a", "b")' + ] + + +def test_insert_overwrite_time_partition_delta( + trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture +): + adapter = trino_mocked_engine_adapter + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["acme_delta_lake", "schema", "test_table"]), + query_or_df=parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + 'MERGE INTO "acme_delta_lake"."schema"."test_table" AS "__merge_target__" USING (SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\') AS "__MERGE_SOURCE__" ON FALSE WHEN NOT MATCHED BY SOURCE AND "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\' THEN DELETE WHEN NOT MATCHED THEN INSERT ("a", "b") VALUES ("a", "b")' + ] + + +def test_insert_overwrite_time_partition_other( + trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture +): + adapter = trino_mocked_engine_adapter + + mocker.patch( + "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_catalog_type", + return_value="other", + ) + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["other", "schema", "test_table"]), + query_or_df=parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "other"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + 'INSERT INTO "other"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + ]