Skip to content

Commit 163ad47

Browse files
committed
Remove databricks, snowflake metadata calls
1 parent d425f85 commit 163ad47

File tree

3 files changed

+35
-101
lines changed

3 files changed

+35
-101
lines changed

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing as t
55
from functools import partial
66

7-
from sqlglot import exp, parse_one
7+
from sqlglot import exp
88
from sqlmesh.core.dialect import to_schema
99
from sqlmesh.core.engine_adapter.shared import (
1010
CatalogSupport,
@@ -14,10 +14,8 @@
1414
SourceQuery,
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
17-
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
1817
from sqlmesh.core.node import IntervalUnit
1918
from sqlmesh.core.schema_diff import NestedSupport
20-
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
2119
from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
2220
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
2321

@@ -36,7 +34,6 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
3634
SUPPORTS_CLONING = True
3735
SUPPORTS_MATERIALIZED_VIEWS = True
3836
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
39-
SUPPORTS_QUERY_EXECUTION_TRACKING = True
4037
SCHEMA_DIFFER_KWARGS = {
4138
"support_positional_add": True,
4239
"nested_support": NestedSupport.ALL,
@@ -366,73 +363,3 @@ def _build_table_properties_exp(
366363
expressions.append(clustered_by_exp)
367364
properties = exp.Properties(expressions=expressions)
368365
return properties
369-
370-
def _record_execution_stats(
371-
self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
372-
) -> None:
373-
parsed = parse_one(sql, dialect=self.dialect)
374-
table = parsed.find(exp.Table)
375-
table_name = table.sql(dialect=self.dialect) if table else None
376-
377-
if table_name:
378-
try:
379-
self.cursor.execute(f"DESCRIBE HISTORY {table_name}")
380-
except:
381-
return
382-
383-
history = (
384-
self.cursor.fetchdf()
385-
if isinstance(self.cursor, SparkSessionCursor)
386-
else self.cursor.fetchall_arrow()
387-
)
388-
if history is not None:
389-
from pandas import DataFrame as PandasDataFrame
390-
from pyspark.sql import DataFrame as PySparkDataFrame
391-
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
392-
393-
history_df = None
394-
if isinstance(history, PandasDataFrame):
395-
history_df = history
396-
elif isinstance(history, (PySparkDataFrame, PySparkConnectDataFrame)):
397-
history_df = history.toPandas()
398-
else:
399-
# arrow table
400-
history_df = history.to_pandas()
401-
402-
if history_df is not None and not history_df.empty:
403-
write_df = history_df[history_df["operation"] == "WRITE"]
404-
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
405-
if not write_df.empty and "operationMetrics" in write_df.columns:
406-
metrics = write_df["operationMetrics"].iloc[0]
407-
if metrics:
408-
rowcount = None
409-
rowcount_str = [
410-
metric[1] for metric in metrics if metric[0] == "numOutputRows"
411-
]
412-
if rowcount_str:
413-
try:
414-
rowcount = int(rowcount_str[0])
415-
except (TypeError, ValueError):
416-
pass
417-
418-
bytes_processed = None
419-
bytes_str = [
420-
metric[1] for metric in metrics if metric[0] == "numOutputBytes"
421-
]
422-
if bytes_str:
423-
try:
424-
bytes_processed = int(bytes_str[0])
425-
except (TypeError, ValueError):
426-
pass
427-
428-
if rowcount is not None or bytes_processed is not None:
429-
# if no rows were written, df contains 0 for bytes but no value for rows
430-
rowcount = (
431-
0
432-
if rowcount is None and bytes_processed is not None
433-
else rowcount
434-
)
435-
436-
QueryExecutionTracker.record_execution(
437-
sql, rowcount, bytes_processed
438-
)

sqlmesh/core/engine_adapter/snowflake.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import contextlib
44
import logging
5+
import re
56
import typing as t
67

7-
from sqlglot import exp, parse_one
8+
from sqlglot import exp
89
from sqlglot.helper import ensure_list
910
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1011
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -672,28 +673,31 @@ def _record_execution_stats(
672673
) -> None:
673674
"""Snowflake does not report row counts for CTAS like other DML operations.
674675
675-
They neither report the sentinel value -1 nor do they report 0 rows. Instead, they return a single data row
676-
containing the string "Table <table_name> successfully created." and a row count of 1.
676+
They neither report the sentinel value -1 nor do they report 0 rows. Instead, they report a rowcount
677+
of 1 and return a single data row containing one of the strings:
678+
- "Table <table_name> successfully created."
679+
- "<table_name> already exists, statement succeeded."
677680
678-
We do not want to record the incorrect row count of 1, so we check whether:
679-
- There is exactly one row to fetch (in general, DML operations should return no rows to fetch from the cursor)
680-
- That row contains the table successfully created string
681-
682-
If so, we return early and do not record the row count.
681+
We do not want to record the incorrect row count of 1, so we check whether that row contains the table
682+
successfully created string. If so, we return early and do not record the row count.
683683
"""
684684
if rowcount == 1:
685-
query_parsed = parse_one(sql, dialect=self.dialect)
686-
if isinstance(query_parsed, exp.Create):
687-
if query_parsed.expression and isinstance(query_parsed.expression, exp.Select):
688-
table = query_parsed.find(exp.Table)
689-
if table:
690-
row_query = f"SELECT ROW_COUNT as row_count FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{table.db}' AND TABLE_NAME = '{table.name}'"
691-
row_query_results = self.fetchone(row_query, quote_identifiers=True)
692-
if row_query_results:
693-
rowcount = row_query_results[0]
694-
else:
695-
return
696-
else:
685+
results = self.cursor.fetchone()
686+
if results:
687+
try:
688+
results_str = str(results[0])
689+
except (ValueError, TypeError):
690+
return
691+
692+
# Snowflake identifiers may be:
693+
# - An unquoted contiguous set of [a-zA-Z0-9_$] characters
694+
# - A double-quoted string that may contain spaces and nested double-quotes represented by `""`. Example: " my ""table"" name "
695+
is_created = re.match(r'Table [a-zA-Z0-9_$"]*? successfully created\.', results_str)
696+
is_already_exists = re.match(
697+
r'[a-zA-Z0-9_$"]*? already exists, statement succeeded\.',
698+
results_str,
699+
)
700+
if is_created or is_already_exists:
697701
return
698702

699703
QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
import sys
88
import typing as t
99
import shutil
10-
from datetime import date, datetime, timedelta
10+
from datetime import datetime, timedelta
1111
from unittest.mock import patch
1212
import numpy as np # noqa: TID253
1313
import pandas as pd # noqa: TID253
1414
import pytest
1515
import pytz
16-
import time_machine
1716
from sqlglot import exp, parse_one
1817
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1918
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -2457,14 +2456,18 @@ def capture_execution_stats(
24572456
assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 3
24582457

24592458
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2460-
assert actual_execution_stats["seed_model"].total_rows_processed == 7
24612459
assert actual_execution_stats["incremental_model"].total_rows_processed == 7
24622460
# snowflake doesn't track rows for CTAS
2463-
assert actual_execution_stats["full_model"].total_rows_processed == 3
2461+
assert actual_execution_stats["full_model"].total_rows_processed == (
2462+
None if ctx.mark.startswith("snowflake") else 3
2463+
)
2464+
assert actual_execution_stats["seed_model"].total_rows_processed == (
2465+
None if ctx.mark.startswith("snowflake") else 7
2466+
)
24642467

2465-
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2466-
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2467-
assert actual_execution_stats["full_model"].total_bytes_processed is not None
2468+
if ctx.mark.startswith("bigquery"):
2469+
assert actual_execution_stats["incremental_model"].total_bytes_processed
2470+
assert actual_execution_stats["full_model"].total_bytes_processed
24682471

24692472
# run that loads 0 rows in incremental model
24702473
# - some cloud DBs error because time travel messes up token expiration

0 commit comments

Comments
 (0)