Skip to content

Commit 2fb3daf

Browse files
committed
Handle snowflake lack of CTAS tracking
1 parent d6196fe commit 2fb3daf

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2413,6 +2413,11 @@ def _log_sql(
24132413

24142414
logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log)
24152415

2416+
def _record_execution_stats(
2417+
self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
2418+
) -> None:
2419+
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)
2420+
24162421
def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.Any) -> None:
24172422
self.cursor.execute(sql, **kwargs)
24182423

@@ -2429,7 +2434,7 @@ def _execute(self, sql: str, track_execution_stats: bool = False, **kwargs: t.An
24292434
except (TypeError, ValueError):
24302435
pass
24312436

2432-
QueryExecutionTracker.record_execution(sql, rowcount, None)
2437+
self._record_execution_stats(sql, rowcount)
24332438

24342439
@contextlib.contextmanager
24352440
def temp_table(

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import logging
5+
import re
56
import typing as t
67

78
from sqlglot import exp
@@ -24,6 +25,7 @@
2425
set_catalog,
2526
)
2627
from sqlmesh.core.schema_diff import SchemaDiffer
28+
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
2729
from sqlmesh.utils import optional_import, get_source_columns_to_types
2830
from sqlmesh.utils.errors import SQLMeshError
2931
from sqlmesh.utils.pandas import columns_to_types_from_dtypes
@@ -72,6 +74,7 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin, ClusteredByMixi
7274
)
7375
MANAGED_TABLE_KIND = "DYNAMIC TABLE"
7476
SNOWPARK = "snowpark"
77+
SUPPORTS_QUERY_EXECUTION_TRACKING = True
7578

7679
@contextlib.contextmanager
7780
def session(self, properties: SessionProperties) -> t.Iterator[None]:
@@ -664,3 +667,33 @@ def close(self) -> t.Any:
664667
self._connection_pool.set_attribute(self.SNOWPARK, None)
665668

666669
return super().close()
670+
671+
def _record_execution_stats(
672+
self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
673+
) -> None:
674+
"""Snowflake does not report row counts for CTAS like other DML operations.
675+
676+
They neither report the sentinel value -1 nor do they report 0 rows. Instead, they return a single data row
677+
containing the string "Table <table_name> successfully created." and a row count of 1.
678+
679+
We do not want to record the row count of 1 for CTAS operations, so we check for that data pattern and return
680+
early if it is detected.
681+
682+
Regex explanation - Snowflake identifiers may be:
683+
- An unquoted contiguous set of [a-zA-Z0-9_$] characters
684+
- A double-quoted string that may contain spaces and nested double-quotes represented by `""`
685+
- Example: " my ""table"" name "
686+
- Pattern: "(?:[^"]|"")+"
687+
- ?: is a non-capturing group
688+
- [^"] matches any single character except a double-quote
689+
- "" matches two sequential double-quotes
690+
"""
691+
if rowcount == 1:
692+
results = self.cursor.fetchall()
693+
if results and len(results) == 1:
694+
is_ctas = re.match(
695+
r'Table ([a-zA-Z0-9_$]+|"(?:[^"]|"")+") successfully created\.', results[0][0]
696+
)
697+
if is_ctas:
698+
return
699+
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)

sqlmesh/core/snapshot/execution_tracker.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
@dataclass
1111
class QueryExecutionStats:
1212
snapshot_batch_id: str
13-
total_rows_processed: int = 0
14-
total_bytes_processed: int = 0
13+
total_rows_processed: t.Optional[int] = None
14+
total_bytes_processed: t.Optional[int] = None
1515
query_count: int = 0
1616
queries_executed: t.List[t.Tuple[str, t.Optional[int], t.Optional[int], float]] = field(
1717
default_factory=list
@@ -42,12 +42,18 @@ def add_execution(
4242
self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
4343
) -> None:
4444
if row_count is not None and row_count >= 0:
45-
self.stats.total_rows_processed += row_count
45+
if self.stats.total_rows_processed is None:
46+
self.stats.total_rows_processed = row_count
47+
else:
48+
self.stats.total_rows_processed += row_count
4649

4750
# conditional on row_count because we should only count bytes corresponding to
4851
# DML actions whose rows were captured
4952
if bytes_processed is not None and bytes_processed >= 0:
50-
self.stats.total_bytes_processed += bytes_processed
53+
if self.stats.total_bytes_processed is None:
54+
self.stats.total_bytes_processed = bytes_processed
55+
else:
56+
self.stats.total_bytes_processed += bytes_processed
5157

5258
self.stats.query_count += 1
5359
# TODO: remove this

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2413,7 +2413,12 @@ def capture_execution_stats(
24132413

24142414
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
24152415
assert actual_execution_stats["incremental_model"].total_rows_processed == 7
2416-
assert actual_execution_stats["full_model"].total_rows_processed == 3
2416+
# snowflake doesn't track rows for CTAS
2417+
assert actual_execution_stats["full_model"].total_rows_processed == (
2418+
None if ctx.mark.startswith("snowflake") else 3
2419+
)
2420+
# seed rows aren't tracked
2421+
assert actual_execution_stats["seed_model"].total_rows_processed is None
24172422

24182423
if ctx.mark.startswith("bigquery"):
24192424
assert actual_execution_stats["incremental_model"].total_bytes_processed

0 commit comments

Comments (0)