Skip to content

Commit 28277ff

Browse files
committed
Fix: Fix Transactions with Truncate (#1880)
* Fix transactions with truncate
* Fix tests

(cherry picked from commit a6d8d90)
1 parent 6343d98 commit 28277ff

File tree

11 files changed

+117
-42
lines changed

11 files changed

+117
-42
lines changed

sqlmesh/core/config/connection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ def _connection_kwargs_keys(self) -> t.Set[str]:
265265
def _engine_adapter(self) -> t.Type[EngineAdapter]:
266266
return engine_adapter.SnowflakeEngineAdapter
267267

268+
@property
269+
def _static_connection_kwargs(self) -> t.Dict[str, t.Any]:
270+
return {"autocommit": False}
271+
268272
@property
269273
def _connection_factory(self) -> t.Callable:
270274
from snowflake import connector

sqlmesh/core/engine_adapter/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,9 +1551,9 @@ def _add_where_to_query(
15511551

15521552
return query
15531553

1554-
def _truncate_table(self, table_name: TableName) -> str:
1555-
table = quote_identifiers(exp.to_table(table_name))
1556-
return f"TRUNCATE {table.sql(dialect=self.dialect)}"
1554+
def _truncate_table(self, table_name: TableName) -> None:
1555+
table = exp.to_table(table_name)
1556+
self.execute(f"TRUNCATE TABLE {table.sql(dialect=self.dialect, identify=True)}")
15571557

15581558

15591559
class EngineAdapterWithIndexSupport(EngineAdapter):

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def replace_table(
100100
temp_query = query.transform(
101101
replace_table, curr_table=target_table, new_table=temp_table
102102
)
103-
engine_adapter.execute(engine_adapter._truncate_table(target_table))
103+
engine_adapter._truncate_table(target_table)
104104
return engine_adapter._insert_append_query(target_table, temp_query, columns_to_types)
105105

106106
def replace_query(
@@ -138,7 +138,7 @@ def replace_query(
138138
return self.overwrite_target_from_temp(
139139
self, query, columns_to_types, target_table
140140
)
141-
self.execute(self._truncate_table(table_name))
141+
self._truncate_table(table_name)
142142
return self._insert_append_query(table_name, query, columns_to_types)
143143

144144

@@ -260,3 +260,12 @@ def get_current_catalog(self) -> t.Optional[str]:
260260
if result:
261261
return result[0]
262262
return None
263+
264+
265+
class NonTransactionalTruncateMixin(EngineAdapter):
266+
def _truncate_table(self, table_name: TableName) -> None:
267+
# Truncate forces a commit of the current transaction so we want to do an unconditional delete to
268+
# preserve the transaction if one exists otherwise we can truncate
269+
if self._connection_pool.is_transaction_active:
270+
return self.execute(exp.Delete(this=exp.to_table(table_name)))
271+
super()._truncate_table(table_name)

sqlmesh/core/engine_adapter/mssql.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import pandas as pd
99
from pandas.api.types import is_datetime64_any_dtype # type: ignore
1010
from sqlglot import exp
11-
from sqlglot.optimizer.qualify_columns import quote_identifiers
1211

1312
from sqlmesh.core.engine_adapter.base import (
1413
CatalogSupport,
@@ -243,10 +242,6 @@ def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
243242
for row in dataframe.itertuples()
244243
]
245244

246-
def _truncate_table(self, table_name: TableName) -> str:
247-
table = quote_identifiers(exp.to_table(table_name))
248-
return f"TRUNCATE TABLE {table.sql(dialect=self.dialect)}"
249-
250245
def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str:
251246
sql = super()._to_sql(expression, quote=quote, **kwargs)
252247
return f"{sql};"

sqlmesh/core/engine_adapter/mysql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlmesh.core.engine_adapter.mixins import (
66
LogicalMergeMixin,
77
LogicalReplaceQueryMixin,
8+
NonTransactionalTruncateMixin,
89
PandasNativeFetchDFSupportMixin,
910
)
1011
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType, set_catalog
@@ -17,6 +18,7 @@ class MySQLEngineAdapter(
1718
LogicalMergeMixin,
1819
LogicalReplaceQueryMixin,
1920
PandasNativeFetchDFSupportMixin,
21+
NonTransactionalTruncateMixin,
2022
):
2123
DEFAULT_BATCH_SIZE = 200
2224
DIALECT = "mysql"

sqlmesh/core/engine_adapter/redshift.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import typing as t
4-
import uuid
54

65
import pandas as pd
76
from sqlglot import exp
@@ -12,6 +11,7 @@
1211
GetCurrentCatalogFromFunctionMixin,
1312
LogicalMergeMixin,
1413
LogicalReplaceQueryMixin,
14+
NonTransactionalTruncateMixin,
1515
)
1616
from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType, set_catalog
1717

@@ -25,6 +25,7 @@ class RedshiftEngineAdapter(
2525
LogicalReplaceQueryMixin,
2626
LogicalMergeMixin,
2727
GetCurrentCatalogFromFunctionMixin,
28+
NonTransactionalTruncateMixin,
2829
):
2930
DIALECT = "redshift"
3031
ESCAPE_JSON = True
@@ -180,12 +181,8 @@ def replace_query(
180181
columns_to_types = columns_to_types or self.columns(table_name)
181182
target_table = exp.to_table(table_name)
182183
with self.transaction():
183-
temp_table_name = f"{target_table.alias_or_name}_temp_{self._short_hash()}"
184-
temp_table = target_table.copy()
185-
temp_table.set("this", exp.to_identifier(temp_table_name))
186-
old_table_name = f"{target_table.alias_or_name}_old_{self._short_hash()}"
187-
old_table = target_table.copy()
188-
old_table.set("this", exp.to_identifier(old_table_name))
184+
temp_table = self._get_temp_table(target_table)
185+
old_table = self._get_temp_table(target_table)
189186
self.create_table(temp_table, columns_to_types, exists=False, **kwargs)
190187
self._insert_append_source_queries(temp_table, source_queries, columns_to_types)
191188
self.rename_table(target_table, old_table)
@@ -200,17 +197,18 @@ def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
200197
"""
201198
Returns all the data objects that exist in the given schema and optionally catalog.
202199
"""
200+
catalog_name = self.get_current_catalog()
203201
query = f"""
204202
SELECT
205-
null AS catalog_name,
203+
'{catalog_name}' AS catalog_name,
206204
tablename AS name,
207205
schemaname AS schema_name,
208206
'TABLE' AS type
209207
FROM pg_tables
210208
WHERE schemaname ILIKE '{schema_name}'
211209
UNION ALL
212210
SELECT
213-
null AS catalog_name,
211+
'{catalog_name}' AS catalog_name,
214212
viewname AS name,
215213
schemaname AS schema_name,
216214
'VIEW' AS type
@@ -219,7 +217,7 @@ def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
219217
AND definition not ilike '%create materialized view%'
220218
UNION ALL
221219
SELECT
222-
null AS catalog_name,
220+
'{catalog_name}' AS catalog_name,
223221
viewname AS name,
224222
schemaname AS schema_name,
225223
'MATERIALIZED_VIEW' AS type
@@ -236,9 +234,6 @@ def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
236234
for row in df.itertuples()
237235
]
238236

239-
def _short_hash(self) -> str:
240-
return uuid.uuid4().hex[:8]
241-
242237

243238
def parse_plan(plan: str) -> t.Optional[t.Dict]:
244239
"""Parse the output of a redshift explain verbose query plan into a Python dict."""

sqlmesh/core/engine_adapter/spark.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,6 @@ def wap_publish(self, table_name: TableName, wap_id: str) -> None:
500500
)
501501
self.execute(f"ALTER TABLE {fqn.sql(dialect=self.dialect)} DROP BRANCH {branch_name}")
502502

503-
def _truncate_table(self, table_name: TableName) -> str:
504-
table = quote_identifiers(exp.to_table(table_name))
505-
return f"TRUNCATE TABLE {table.sql(dialect=self.dialect)}"
506-
507503
def _ensure_fqn(self, table_name: TableName) -> exp.Table:
508504
if isinstance(table_name, exp.Table):
509505
table_name = table_name.copy()

sqlmesh/core/engine_adapter/trino.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pandas as pd
66
from pandas.api.types import is_datetime64_any_dtype # type: ignore
77
from sqlglot import exp
8-
from sqlglot.optimizer.qualify_columns import quote_identifiers
98

109
from sqlmesh.core.dialect import to_schema
1110
from sqlmesh.core.engine_adapter.base import (
@@ -68,9 +67,10 @@ def _insert_overwrite_by_condition(
6867
f"SET SESSION {self.get_current_catalog()}.insert_existing_partitions_behavior='APPEND'"
6968
)
7069

71-
def _truncate_table(self, table_name: TableName) -> str:
72-
table = quote_identifiers(exp.to_table(table_name))
73-
return f"DELETE FROM {table.sql(dialect=self.dialect)}"
70+
def _truncate_table(self, table_name: TableName) -> None:
71+
table = exp.to_table(table_name)
72+
# Some trino connectors don't support truncate so we use delete.
73+
self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}")
7474

7575
def _get_data_objects(self, schema_name: SchemaName) -> t.List[DataObject]:
7676
"""

tests/core/engine_adapter/test_integration.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,53 @@ def test_scd_type_2(ctx: TestContext):
883883
)
884884

885885

886+
def test_truncate_table(ctx: TestContext):
887+
if ctx.test_type != "query":
888+
pytest.skip("Truncate table test does not change based on input data type")
889+
890+
ctx.init()
891+
table = ctx.table("test_table")
892+
ctx.engine_adapter.create_table(table, ctx.columns_to_types)
893+
input_data = pd.DataFrame(
894+
[
895+
{"id": 1, "ds": "2022-01-01"},
896+
{"id": 2, "ds": "2022-01-02"},
897+
{"id": 3, "ds": "2022-01-03"},
898+
]
899+
)
900+
ctx.engine_adapter.insert_append(table, ctx.input_data(input_data))
901+
ctx.compare_with_current(table, input_data)
902+
ctx.engine_adapter._truncate_table(table)
903+
assert ctx.engine_adapter.fetchone(exp.select("count(*)").from_(table))[0] == 0
904+
905+
906+
def test_transaction(ctx: TestContext):
907+
if ctx.engine_adapter.SUPPORTS_TRANSACTIONS is False:
908+
pytest.skip(f"Engine adapter {ctx.engine_adapter.dialect} doesn't support transactions")
909+
if ctx.test_type != "query":
910+
pytest.skip("Transaction test can just run for query")
911+
912+
ctx.init()
913+
table = ctx.table("test_table")
914+
input_data = pd.DataFrame(
915+
[
916+
{"id": 1, "ds": "2022-01-01"},
917+
{"id": 2, "ds": "2022-01-02"},
918+
{"id": 3, "ds": "2022-01-03"},
919+
]
920+
)
921+
with ctx.engine_adapter.transaction():
922+
ctx.engine_adapter.create_table(table, ctx.columns_to_types)
923+
ctx.engine_adapter.insert_append(
924+
table, ctx.input_data(input_data, ctx.columns_to_types), ctx.columns_to_types
925+
)
926+
ctx.compare_with_current(table, input_data)
927+
with ctx.engine_adapter.transaction():
928+
ctx.engine_adapter._truncate_table(table)
929+
ctx.engine_adapter._connection_pool.rollback()
930+
ctx.compare_with_current(table, input_data)
931+
932+
886933
def test_sushi(ctx: TestContext):
887934
if ctx.test_type != "query":
888935
pytest.skip("Sushi end-to-end tests only need to run for query and pyspark tests")

tests/core/engine_adapter/test_mixins.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlmesh.core.engine_adapter.mixins import (
99
LogicalMergeMixin,
1010
LogicalReplaceQueryMixin,
11+
NonTransactionalTruncateMixin,
1112
)
1213
from tests.core.engine_adapter import to_sql_calls
1314

@@ -30,7 +31,7 @@ def test_logical_replace_query_already_exists(
3031
adapter.replace_query("db.table", parse_one("SELECT col FROM db.other_table"))
3132

3233
assert to_sql_calls(adapter) == [
33-
'TRUNCATE "db"."table"',
34+
'TRUNCATE TABLE "db"."table"',
3435
'INSERT INTO "db"."table" ("col") SELECT "col" FROM "db"."other_table"',
3536
]
3637

@@ -76,7 +77,7 @@ def test_logical_replace_self_reference(
7677
assert to_sql_calls(adapter) == [
7778
f'CREATE SCHEMA IF NOT EXISTS "db"',
7879
f'CREATE TABLE IF NOT EXISTS "db"."__temp_table_{temp_table_id}" AS SELECT "col" FROM "db"."table"',
79-
'TRUNCATE "db"."table"',
80+
'TRUNCATE TABLE "db"."table"',
8081
f'INSERT INTO "db"."table" ("col") SELECT "col" + 1 AS "col" FROM "db"."__temp_table_{temp_table_id}"',
8182
f'DROP TABLE IF EXISTS "db"."__temp_table_{temp_table_id}"',
8283
]
@@ -135,3 +136,23 @@ def test_logical_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFix
135136
call('''DROP TABLE IF EXISTS "temporary"'''),
136137
]
137138
)
139+
140+
141+
def test_non_transaction_truncate_mixin(
142+
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
143+
):
144+
adapter = make_mocked_engine_adapter(NonTransactionalTruncateMixin, "redshift")
145+
adapter._truncate_table(table_name="test_table")
146+
147+
assert to_sql_calls(adapter) == ['TRUNCATE TABLE "test_table"']
148+
149+
150+
def test_non_transaction_truncate_mixin_within_transaction(
151+
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable
152+
):
153+
adapter = make_mocked_engine_adapter(NonTransactionalTruncateMixin, "redshift")
154+
adapter._connection_pool = mocker.MagicMock()
155+
adapter._connection_pool.is_transaction_active = True
156+
adapter._truncate_table(table_name="test_table")
157+
158+
assert to_sql_calls(adapter) == ['DELETE FROM "test_table"']

Comments (0)