Skip to content

Commit 3ee5976

Browse files
committed
feat: Trino — support REPLACE TABLE for Iceberg and Delta Lake catalogs
1 parent 70c8deb commit 3ee5976

File tree

7 files changed

+105
-15
lines changed

7 files changed

+105
-15
lines changed

sqlmesh/core/engine_adapter/athena.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,7 @@ def replace_query(
437437
table_description: t.Optional[str] = None,
438438
column_descriptions: t.Optional[t.Dict[str, str]] = None,
439439
source_columns: t.Optional[t.List[str]] = None,
440+
supports_replace_table_override: t.Optional[bool] = None,
440441
**kwargs: t.Any,
441442
) -> None:
442443
table = exp.to_table(table_name)

sqlmesh/core/engine_adapter/base.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,16 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str:
432432
)
433433
return self.DEFAULT_CATALOG_TYPE
434434

435+
def get_catalog_type_from_table(self, table: TableName) -> str:
436+
"""Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type"""
437+
catalog = exp.to_table(table).catalog or self.get_current_catalog()
438+
return self.get_catalog_type(catalog)
439+
435440
@property
436441
def current_catalog_type(self) -> str:
442+
# Prefer `get_catalog_type_from_table` over this property: what matters is the catalog type of the
443+
# table targeted by the operation, not the catalog type of the connection.
444+
# This property remains only for legacy reasons and should be refactored out.
437445
return self.get_catalog_type(self.get_current_catalog())
438446

439447
def replace_query(
@@ -444,6 +452,7 @@ def replace_query(
444452
table_description: t.Optional[str] = None,
445453
column_descriptions: t.Optional[t.Dict[str, str]] = None,
446454
source_columns: t.Optional[t.List[str]] = None,
455+
supports_replace_table_override: t.Optional[bool] = None,
447456
**kwargs: t.Any,
448457
) -> None:
449458
"""Replaces an existing table with a query.
@@ -494,12 +503,17 @@ def replace_query(
494503
)
495504
# All engines support `CREATE TABLE AS` so we use that if the table doesn't already exist and we
496505
# use `CREATE OR REPLACE TABLE AS` if the engine supports it
497-
if self.SUPPORTS_REPLACE_TABLE or not table_exists:
506+
supports_replace_table = (
507+
self.SUPPORTS_REPLACE_TABLE
508+
if supports_replace_table_override is None
509+
else supports_replace_table_override
510+
)
511+
if supports_replace_table or not table_exists:
498512
return self._create_table_from_source_queries(
499513
target_table,
500514
source_queries,
501515
target_columns_to_types,
502-
replace=self.SUPPORTS_REPLACE_TABLE,
516+
replace=supports_replace_table,
503517
table_description=table_description,
504518
column_descriptions=column_descriptions,
505519
**kwargs,

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _build_view_properties_exp(
228228

229229
def _truncate_comment(self, comment: str, length: t.Optional[int]) -> str:
230230
# iceberg and delta do not have a comment length limit
231-
if self.current_catalog_type in ("iceberg", "delta"):
231+
if self.current_catalog_type in ("iceberg", "delta_lake"):
232232
return comment
233233
return super()._truncate_comment(comment, length)
234234

sqlmesh/core/engine_adapter/redshift.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def replace_query(
253253
table_description: t.Optional[str] = None,
254254
column_descriptions: t.Optional[t.Dict[str, str]] = None,
255255
source_columns: t.Optional[t.List[str]] = None,
256+
supports_replace_table_override: t.Optional[bool] = None,
256257
**kwargs: t.Any,
257258
) -> None:
258259
"""

sqlmesh/core/engine_adapter/trino.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
3535
from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF
3636

37+
CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {"iceberg", "delta_lake"}
38+
3739

3840
@set_catalog()
3941
class TrinoEngineAdapter(
@@ -115,6 +117,37 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]:
115117
finally:
116118
self.execute(f"RESET SESSION AUTHORIZATION")
117119

120+
def replace_query(
121+
self,
122+
table_name: TableName,
123+
query_or_df: QueryOrDF,
124+
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
125+
table_description: t.Optional[str] = None,
126+
column_descriptions: t.Optional[t.Dict[str, str]] = None,
127+
source_columns: t.Optional[t.List[str]] = None,
128+
supports_replace_table_override: t.Optional[bool] = None,
129+
**kwargs: t.Any,
130+
) -> None:
131+
catalog_type = self.get_catalog_type(self.get_catalog_type_from_table(table_name))
132+
# Users may use a custom catalog name, so we assume the catalog type is still present within the name.
133+
# Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
134+
supports_replace_table_override = None
135+
for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE:
136+
if replace_table_catalog_type in catalog_type:
137+
supports_replace_table_override = True
138+
break
139+
140+
super().replace_query(
141+
table_name=table_name,
142+
query_or_df=query_or_df,
143+
target_columns_to_types=target_columns_to_types,
144+
table_description=table_description,
145+
column_descriptions=column_descriptions,
146+
source_columns=source_columns,
147+
supports_replace_table_override=supports_replace_table_override,
148+
**kwargs,
149+
)
150+
118151
def _insert_overwrite_by_condition(
119152
self,
120153
table_name: TableName,
@@ -250,7 +283,7 @@ def _build_schema_exp(
250283
expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
251284
is_view: bool = False,
252285
) -> exp.Schema:
253-
if self.current_catalog_type == "delta_lake":
286+
if "delta_lake" in self.get_catalog_type_from_table(table):
254287
target_columns_to_types = self._to_delta_ts(target_columns_to_types)
255288

256289
return super()._build_schema_exp(
@@ -277,7 +310,9 @@ def _scd_type_2(
277310
source_columns: t.Optional[t.List[str]] = None,
278311
**kwargs: t.Any,
279312
) -> None:
280-
if target_columns_to_types and self.current_catalog_type == "delta_lake":
313+
if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
314+
target_table
315+
):
281316
target_columns_to_types = self._to_delta_ts(target_columns_to_types)
282317

283318
return super()._scd_type_2(
@@ -381,7 +416,7 @@ def _create_table(
381416
else:
382417
table_name = table_name_or_schema
383418

384-
if self.current_catalog_type == "hive":
419+
if "hive" in self.get_catalog_type_from_table(table_name):
385420
# the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
386421
# (even if metadata TTL is set to 0s)
387422
# Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail

tests/core/engine_adapter/integration/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def pytest_marks(self) -> t.List[MarkDecorator]:
7171
IntegrationTestEngine("postgres"),
7272
IntegrationTestEngine("mysql"),
7373
IntegrationTestEngine("mssql"),
74-
IntegrationTestEngine("trino", catalog_types=["hive", "iceberg", "delta", "nessie"]),
74+
IntegrationTestEngine("trino", catalog_types=["hive", "iceberg", "delta_lake", "nessie"]),
7575
IntegrationTestEngine("spark", native_dataframe_type="pyspark"),
7676
IntegrationTestEngine("clickhouse", catalog_types=["standalone", "cluster"]),
7777
IntegrationTestEngine("risingwave"),

tests/core/engine_adapter/test_trino.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def trino_mocked_engine_adapter(
2424
def mock_catalog_type(catalog_name):
2525
if "iceberg" in catalog_name:
2626
return "iceberg"
27-
if "delta" in catalog_name:
28-
return "delta"
27+
if "delta_lake" in catalog_name:
28+
return "delta_lake"
2929
return "hive"
3030

3131
mocker.patch(
@@ -50,7 +50,7 @@ def test_set_current_catalog(trino_mocked_engine_adapter: TrinoEngineAdapter):
5050
]
5151

5252

53-
@pytest.mark.parametrize("storage_type", ["iceberg", "delta"])
53+
@pytest.mark.parametrize("storage_type", ["iceberg", "delta_lake"])
5454
def test_get_catalog_type(
5555
trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture, storage_type: str
5656
):
@@ -64,13 +64,14 @@ def test_get_catalog_type(
6464
assert adapter.get_catalog_type("foo") == TrinoEngineAdapter.DEFAULT_CATALOG_TYPE
6565
assert adapter.get_catalog_type("datalake_hive") == "hive"
6666
assert adapter.get_catalog_type("datalake_iceberg") == "iceberg"
67-
assert adapter.get_catalog_type("datalake_delta") == "delta"
67+
assert adapter.get_catalog_type("datalake_delta_lake") == "delta_lake"
6868

6969
mocker.patch(
7070
"sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog",
7171
return_value=f"system_{storage_type}",
7272
)
73-
assert adapter.current_catalog_type == storage_type
73+
expected_current_type = storage_type
74+
assert adapter.current_catalog_type == expected_current_type
7475

7576

7677
def test_get_catalog_type_cached(
@@ -103,7 +104,7 @@ def mock_fetchone(sql):
103104
assert fetchone_mock.call_count == 2
104105

105106

106-
@pytest.mark.parametrize("storage_type", ["hive", "delta"])
107+
@pytest.mark.parametrize("storage_type", ["hive", "delta_lake"])
107108
def test_partitioned_by_hive_delta(
108109
trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture, storage_type: str
109110
):
@@ -113,7 +114,8 @@ def test_partitioned_by_hive_delta(
113114
"sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog",
114115
return_value=f"datalake_{storage_type}",
115116
)
116-
assert adapter.get_catalog_type(f"datalake_{storage_type}") == storage_type
117+
expected_type = storage_type
118+
assert adapter.get_catalog_type(f"datalake_{storage_type}") == expected_type
117119

118120
columns_to_types = {
119121
"cola": exp.DataType.build("INT"),
@@ -314,7 +316,7 @@ def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Call
314316
]
315317

316318

317-
@pytest.mark.parametrize("storage_type", ["iceberg", "delta"])
319+
@pytest.mark.parametrize("storage_type", ["iceberg", "delta_lake"])
318320
def test_comments_iceberg_delta(
319321
mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, storage_type: str
320322
):
@@ -646,3 +648,40 @@ def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter):
646648
"SELECT 1",
647649
"RESET SESSION AUTHORIZATION",
648650
]
651+
652+
653+
@pytest.mark.parametrize(
654+
"catalog_name,expected_replace",
655+
[
656+
("hive_catalog", False),
657+
("iceberg_catalog", True),
658+
("delta_catalog", False),
659+
("acme_delta_lake", True),
660+
("acme_iceberg", True),
661+
("custom_delta_lake_something", True),
662+
("my_iceberg_store", True),
663+
("plain_catalog", False),
664+
],
665+
)
666+
def test_replace_table_catalog_support(
667+
trino_mocked_engine_adapter: TrinoEngineAdapter, catalog_name, expected_replace
668+
):
669+
adapter = trino_mocked_engine_adapter
670+
671+
adapter.replace_query(
672+
table_name=".".join([catalog_name, "schema", "test_table"]),
673+
query_or_df=parse_one("SELECT 1 AS col"),
674+
)
675+
676+
sql_calls = to_sql_calls(adapter)
677+
assert len(sql_calls) == 1
678+
if expected_replace:
679+
assert (
680+
sql_calls[0]
681+
== f'CREATE OR REPLACE TABLE "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
682+
)
683+
else:
684+
assert (
685+
sql_calls[0]
686+
== f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"'
687+
)

0 commit comments

Comments
 (0)