From 4744324493645f4e65a2c3321ba362ca00d95c1b Mon Sep 17 00:00:00 2001
From: eakmanrq <6326532+eakmanrq@users.noreply.github.com>
Date: Tue, 2 Sep 2025 09:43:02 -0700
Subject: [PATCH] feat: trino support replace table iceberg and delta

---
 sqlmesh/core/engine_adapter/athena.py   |  1 +
 sqlmesh/core/engine_adapter/base.py     | 18 +++++++-
 sqlmesh/core/engine_adapter/mixins.py   |  2 +-
 sqlmesh/core/engine_adapter/redshift.py |  1 +
 sqlmesh/core/engine_adapter/trino.py    | 41 ++++++++++++++++--
 tests/core/engine_adapter/test_trino.py | 55 +++++++++++++++++++++----
 6 files changed, 104 insertions(+), 14 deletions(-)

diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py
index aa8a5ce0c1..bd84ba5276 100644
--- a/sqlmesh/core/engine_adapter/athena.py
+++ b/sqlmesh/core/engine_adapter/athena.py
@@ -437,6 +437,7 @@ def replace_query(
         table_description: t.Optional[str] = None,
         column_descriptions: t.Optional[t.Dict[str, str]] = None,
         source_columns: t.Optional[t.List[str]] = None,
+        supports_replace_table_override: t.Optional[bool] = None,
         **kwargs: t.Any,
     ) -> None:
         table = exp.to_table(table_name)
diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py
index 2901831940..c48ce2154d 100644
--- a/sqlmesh/core/engine_adapter/base.py
+++ b/sqlmesh/core/engine_adapter/base.py
@@ -432,8 +432,16 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
             )
         return self.DEFAULT_CATALOG_TYPE
 
+    def get_catalog_type_from_table(self, table: TableName) -> str:
+        """Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type"""
+        catalog = exp.to_table(table).catalog or self.get_current_catalog()
+        return self.get_catalog_type(catalog)
+
     @property
     def current_catalog_type(self) -> str:
+        # `get_catalog_type_from_table` should be used over this property. Reason is that the table that is the target
+        # of the operation is what matters and not the catalog type of the connection.
+        # This still remains for legacy reasons and should be refactored out.
         return self.get_catalog_type(self.get_current_catalog())
 
     def replace_query(
@@ -444,6 +452,7 @@ def replace_query(
         table_description: t.Optional[str] = None,
         column_descriptions: t.Optional[t.Dict[str, str]] = None,
         source_columns: t.Optional[t.List[str]] = None,
+        supports_replace_table_override: t.Optional[bool] = None,
         **kwargs: t.Any,
     ) -> None:
         """Replaces an existing table with a query.
@@ -494,12 +503,17 @@ def replace_query(
         )
         # All engines support `CREATE TABLE AS` so we use that if the table doesn't already exist and we
         # use `CREATE OR REPLACE TABLE AS` if the engine supports it
-        if self.SUPPORTS_REPLACE_TABLE or not table_exists:
+        supports_replace_table = (
+            self.SUPPORTS_REPLACE_TABLE
+            if supports_replace_table_override is None
+            else supports_replace_table_override
+        )
+        if supports_replace_table or not table_exists:
             return self._create_table_from_source_queries(
                 target_table,
                 source_queries,
                 target_columns_to_types,
-                replace=self.SUPPORTS_REPLACE_TABLE,
+                replace=supports_replace_table,
                 table_description=table_description,
                 column_descriptions=column_descriptions,
                 **kwargs,
diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py
index bc83beb3d4..865e47fb93 100644
--- a/sqlmesh/core/engine_adapter/mixins.py
+++ b/sqlmesh/core/engine_adapter/mixins.py
@@ -228,7 +228,7 @@ def _build_view_properties_exp(
 
     def _truncate_comment(self, comment: str, length: t.Optional[int]) -> str:
         # iceberg and delta do not have a comment length limit
-        if self.current_catalog_type in ("iceberg", "delta"):
+        if self.current_catalog_type in ("iceberg", "delta_lake"):
             return comment
 
         return super()._truncate_comment(comment, length)
diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py
index 7d14207b52..7979268473 100644
--- a/sqlmesh/core/engine_adapter/redshift.py
+++ b/sqlmesh/core/engine_adapter/redshift.py
@@ -253,6 +253,7 @@ def replace_query(
         table_description: t.Optional[str] = None,
         column_descriptions: t.Optional[t.Dict[str, str]] = None,
         source_columns: t.Optional[t.List[str]] = None,
+        supports_replace_table_override: t.Optional[bool] = None,
         **kwargs: t.Any,
     ) -> None:
         """
diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py
index 4cef557d94..90b3da5240 100644
--- a/sqlmesh/core/engine_adapter/trino.py
+++ b/sqlmesh/core/engine_adapter/trino.py
@@ -34,6 +34,8 @@
 from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
 from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF
 
+CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {"iceberg", "delta_lake"}
+
 
 @set_catalog()
 class TrinoEngineAdapter(
@@ -115,6 +117,37 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
         finally:
             self.execute(f"RESET SESSION AUTHORIZATION")
 
+    def replace_query(
+        self,
+        table_name: TableName,
+        query_or_df: QueryOrDF,
+        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        source_columns: t.Optional[t.List[str]] = None,
+        supports_replace_table_override: t.Optional[bool] = None,
+        **kwargs: t.Any,
+    ) -> None:
+        # Resolve the catalog type of the *target* table (falls back to the current catalog).
+        # NOTE(review): the type must be resolved exactly once here; wrapping this in another
+        # `get_catalog_type(...)` call would treat a type string as a catalog name.
+        catalog_type = self.get_catalog_type_from_table(table_name)
+        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
+        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
+        # Only auto-detect when the caller did not pass an explicit override.
+        if supports_replace_table_override is None:
+            for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE:
+                if replace_table_catalog_type in catalog_type:
+                    supports_replace_table_override = True
+                    break
+
+        super().replace_query(
+            table_name=table_name,
+            query_or_df=query_or_df,
+            target_columns_to_types=target_columns_to_types,
+            table_description=table_description,
+            column_descriptions=column_descriptions,
+            source_columns=source_columns,
+            supports_replace_table_override=supports_replace_table_override,
+            **kwargs,
+        )
+
     def _insert_overwrite_by_condition(
         self,
         table_name: TableName,
@@ -250,7 +283,7 @@ def _build_schema_exp(
         expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
         is_view: bool = False,
     ) -> exp.Schema:
-        if self.current_catalog_type == "delta_lake":
+        if "delta_lake" in self.get_catalog_type_from_table(table):
             target_columns_to_types = self._to_delta_ts(target_columns_to_types)
 
         return super()._build_schema_exp(
@@ -277,7 +310,9 @@ def _scd_type_2(
         source_columns: t.Optional[t.List[str]] = None,
         **kwargs: t.Any,
     ) -> None:
-        if target_columns_to_types and self.current_catalog_type == "delta_lake":
+        if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
+            target_table
+        ):
             target_columns_to_types = self._to_delta_ts(target_columns_to_types)
 
         return super()._scd_type_2(
@@ -381,7 +416,7 @@ def _create_table(
         else:
             table_name = table_name_or_schema
 
-        if self.current_catalog_type == "hive":
+        if "hive" in self.get_catalog_type_from_table(table_name):
             # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
             # (even if metadata TTL is set to 0s)
             # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail
diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py
index 745c2bbdfb..930834c8d1 100644
--- a/tests/core/engine_adapter/test_trino.py
+++ b/tests/core/engine_adapter/test_trino.py
@@ -24,8 +24,8 @@ def trino_mocked_engine_adapter(
     def mock_catalog_type(catalog_name):
         if "iceberg" in catalog_name:
             return "iceberg"
-        if "delta" in catalog_name:
-            return "delta"
+        if "delta_lake" in catalog_name:
+            return "delta_lake"
         return "hive"
 
     mocker.patch(
@@ -50,7 +50,7 @@ def test_set_current_catalog(trino_mocked_engine_adapter: TrinoEngineAdapter):
     ]
 
 
-@pytest.mark.parametrize("storage_type", ["iceberg", "delta"])
+@pytest.mark.parametrize("storage_type", ["iceberg", "delta_lake"])
 def test_get_catalog_type(
     trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture, storage_type: str
 ):
@@ -64,13 +64,14 @@ def test_get_catalog_type(
     assert adapter.get_catalog_type("foo") == TrinoEngineAdapter.DEFAULT_CATALOG_TYPE
     assert adapter.get_catalog_type("datalake_hive") == "hive"
     assert adapter.get_catalog_type("datalake_iceberg") == "iceberg"
-    assert adapter.get_catalog_type("datalake_delta") == "delta"
+    assert adapter.get_catalog_type("datalake_delta_lake") == "delta_lake"
 
     mocker.patch(
         "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog",
         return_value=f"system_{storage_type}",
     )
-    assert adapter.current_catalog_type == storage_type
+    expected_current_type = storage_type
+    assert adapter.current_catalog_type == expected_current_type
 
 
 def test_get_catalog_type_cached(
@@ -103,7 +104,7 @@ def mock_fetchone(sql):
     assert fetchone_mock.call_count == 2
 
 
-@pytest.mark.parametrize("storage_type", ["hive", "delta"])
+@pytest.mark.parametrize("storage_type", ["hive", "delta_lake"])
 def test_partitioned_by_hive_delta(
     trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture, storage_type: str
 ):
@@ -113,7 +114,8 @@ def test_partitioned_by_hive_delta(
         "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog",
         return_value=f"datalake_{storage_type}",
     )
-    assert adapter.get_catalog_type(f"datalake_{storage_type}") == storage_type
+    expected_type = storage_type
+    assert adapter.get_catalog_type(f"datalake_{storage_type}") == expected_type
 
     columns_to_types = {
         "cola": exp.DataType.build("INT"),
@@ -314,7 +316,7 @@ def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable):
     ]
 
 
-@pytest.mark.parametrize("storage_type", ["iceberg", "delta"])
+@pytest.mark.parametrize("storage_type", ["iceberg", "delta_lake"])
 def test_comments_iceberg_delta(
     mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, storage_type: str
 ):
@@ -646,3 +648,40 @@ def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter):
         "SELECT 1",
         "RESET SESSION AUTHORIZATION",
     ]
+
+
+@pytest.mark.parametrize(
+    "catalog_name,expected_replace",
+    [
+        ("hive_catalog", False),
+        ("iceberg_catalog", True),
+        ("delta_catalog", False),
+        ("acme_delta_lake", True),
+        ("acme_iceberg", True),
+        ("custom_delta_lake_something", True),
+        ("my_iceberg_store", True),
+        ("plain_catalog", False),
+    ],
+)
+def test_replace_table_catalog_support(
+    trino_mocked_engine_adapter: TrinoEngineAdapter, catalog_name, expected_replace
+):
+    adapter = trino_mocked_engine_adapter
+
+    adapter.replace_query(
+        table_name=".".join([catalog_name, "schema", "test_table"]),
+        query_or_df=parse_one("SELECT 1 AS col"),
+    )
+
+    sql_calls = to_sql_calls(adapter)
+    assert len(sql_calls) == 1
+    if expected_replace:
+        assert (
+            sql_calls[0]
+            == f'CREATE OR REPLACE TABLE "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
+        )
+    else:
+        assert (
+            sql_calls[0]
+            == f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
+        )