Skip to content

Commit b7a6937

Browse files
committed
Add databricks support
1 parent 27a8074 commit b7a6937

File tree

3 files changed

+60
-8
lines changed

3 files changed

+60
-8
lines changed

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import typing as t
55
from functools import partial
66

7-
from sqlglot import exp
7+
from sqlglot import exp, parse_one
88
from sqlmesh.core.dialect import to_schema
99
from sqlmesh.core.engine_adapter.shared import (
1010
CatalogSupport,
@@ -16,6 +16,7 @@
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
1717
from sqlmesh.core.node import IntervalUnit
1818
from sqlmesh.core.schema_diff import SchemaDiffer
19+
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
1920
from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection
2021
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
2122

@@ -34,6 +35,7 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
3435
SUPPORTS_CLONING = True
3536
SUPPORTS_MATERIALIZED_VIEWS = True
3637
SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True
38+
SUPPORTS_QUERY_EXECUTION_TRACKING = True
3739
SCHEMA_DIFFER = SchemaDiffer(
3840
support_positional_add=True,
3941
support_nested_operations=True,
@@ -364,3 +366,52 @@ def _build_table_properties_exp(
364366
expressions.append(clustered_by_exp)
365367
properties = exp.Properties(expressions=expressions)
366368
return properties
369+
370+
def _record_execution_stats(
    self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None
) -> None:
    """Record row/byte counts for ``sql`` with the ``QueryExecutionTracker``.

    Cursor-reported counts are unreliable for Delta writes, so when the statement
    targets a table we read the most recent ``WRITE`` entry from
    ``DESCRIBE HISTORY`` and prefer its ``numOutputRows`` / ``numOutputBytes``
    operation metrics over the values passed in.

    Args:
        sql: The SQL text that was executed.
        rowcount: Fallback row count to record if no history metrics are found.
        bytes_processed: Fallback byte count to record if no history metrics are found.
    """
    parsed = parse_one(sql, dialect=self.dialect)
    table = parsed.find(exp.Table)
    table_name = table.sql(dialect=self.dialect) if table else None

    if table_name:
        try:
            self.cursor.execute(f"DESCRIBE HISTORY {table_name}")
        except Exception:
            # Best-effort: the target may not be a Delta table (e.g. a view),
            # in which case history is unavailable and we record nothing.
            # NOTE: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit.
            return

        history = self.cursor.fetchall_arrow()
        if history.num_rows:
            history_df = history.to_pandas()
            write_df = history_df[history_df["operation"] == "WRITE"]
            write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
            if not write_df.empty:
                # Positional indexing is required here: after the filters above,
                # label 0 may no longer exist in the index, so
                # `write_df["operationMetrics"][0]` can raise a KeyError.
                metrics = write_df["operationMetrics"].iloc[0]
                if metrics:

                    def _metric_as_int(key: str) -> t.Optional[int]:
                        # `metrics` arrives as an iterable of (key, value) pairs
                        # (arrow map converted via pandas) — TODO confirm shape.
                        values = [metric[1] for metric in metrics if metric[0] == key]
                        if values:
                            try:
                                return int(values[0])
                            except (TypeError, ValueError):
                                pass
                        return None

                    rowcount = _metric_as_int("numOutputRows")
                    bytes_processed = _metric_as_int("numOutputBytes")

    if rowcount is not None or bytes_processed is not None:
        # if no rows were written, df contains 0 for bytes but no value for rows
        rowcount = 0 if rowcount is None and bytes_processed is not None else rowcount

        QueryExecutionTracker.record_execution(sql, rowcount, bytes_processed)

sqlmesh/core/snapshot/execution_tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __post_init__(self) -> None:
4141
def add_execution(
4242
self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int]
4343
) -> None:
44-
if row_count is not None:
44+
if row_count is not None and row_count >= 0:
4545
if self.stats.total_rows_processed is None:
4646
self.stats.total_rows_processed = row_count
4747
else:

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,11 +2414,12 @@ def capture_execution_stats(
24142414
# seed rows aren't tracked
24152415
assert actual_execution_stats["seed_model"].total_rows_processed is None
24162416

2417-
if ctx.mark.startswith("bigquery"):
2418-
assert actual_execution_stats["incremental_model"].total_bytes_processed
2419-
assert actual_execution_stats["full_model"].total_bytes_processed
2417+
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2418+
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2419+
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24202420

24212421
# run that loads 0 rows in incremental model
2422+
actual_execution_stats = {}
24222423
with patch.object(
24232424
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
24242425
):
@@ -2432,9 +2433,9 @@ def capture_execution_stats(
24322433
None if ctx.mark.startswith("snowflake") else 3
24332434
)
24342435

2435-
if ctx.mark.startswith("bigquery"):
2436-
assert actual_execution_stats["incremental_model"].total_bytes_processed
2437-
assert actual_execution_stats["full_model"].total_bytes_processed
2436+
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2437+
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2438+
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24382439

24392440
# make and validate unmodified dev environment
24402441
no_change_plan: Plan = context.plan_builder(

0 commit comments

Comments
 (0)