Skip to content

Commit 07214e3

Browse files
committed
Remove time travel test for cloud engines, handle pyspark DFs in dbx
1 parent ecab844 commit 07214e3

File tree

2 files changed

+42
-34
lines changed

2 files changed

+42
-34
lines changed

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SourceQuery,
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
17+
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
1718
from sqlmesh.core.node import IntervalUnit
1819
from sqlmesh.core.schema_diff import SchemaDiffer
1920
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
@@ -380,23 +381,32 @@ def _record_execution_stats(
380381
except:
381382
return
382383

383-
history = self.cursor.fetchall_arrow()
384-
if history.num_rows:
385-
history_df = history.to_pandas()
386-
write_df = history_df[history_df["operation"] == "WRITE"]
387-
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
388-
if not write_df.empty:
389-
metrics = write_df["operationMetrics"][0]
390-
if metrics:
391-
rowcount = None
392-
rowcount_str = [
393-
metric[1] for metric in metrics if metric[0] == "numOutputRows"
394-
]
395-
if rowcount_str:
396-
try:
397-
rowcount = int(rowcount_str[0])
398-
except (TypeError, ValueError):
399-
pass
384+
history = (
385+
self.cursor.fetchdf()
386+
if isinstance(self.cursor, SparkSessionCursor)
387+
else self.cursor.fetchall_arrow()
388+
)
389+
if history is not None:
390+
import pandas as pd
391+
392+
history_df = (
393+
history.to_pandas() if not isinstance(history, pd.DataFrame) else history # type: ignore
394+
)
395+
if not history_df.empty:
396+
write_df = history_df[history_df["operation"] == "WRITE"]
397+
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
398+
if not write_df.empty:
399+
metrics = write_df["operationMetrics"][0]
400+
if metrics:
401+
rowcount = None
402+
rowcount_str = [
403+
metric[1] for metric in metrics if metric[0] == "numOutputRows"
404+
]
405+
if rowcount_str:
406+
try:
407+
rowcount = int(rowcount_str[0])
408+
except (TypeError, ValueError):
409+
pass
400410

401411
bytes_processed = None
402412
bytes_str = [

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,23 +2419,21 @@ def capture_execution_stats(
24192419
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24202420

24212421
# run that loads 0 rows in incremental model
2422-
actual_execution_stats = {}
2423-
with patch.object(
2424-
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2425-
):
2426-
with time_machine.travel(date.today() + timedelta(days=1)):
2427-
context.run()
2428-
2429-
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2430-
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2431-
# snowflake doesn't track rows for CTAS
2432-
assert actual_execution_stats["full_model"].total_rows_processed == (
2433-
None if ctx.mark.startswith("snowflake") else 3
2434-
)
2435-
2436-
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2437-
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2438-
assert actual_execution_stats["full_model"].total_bytes_processed is not None
2422+
# - some cloud DBs error because time travel messes up token expiration
2423+
if not ctx.is_remote:
2424+
actual_execution_stats = {}
2425+
with patch.object(
2426+
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2427+
):
2428+
with time_machine.travel(date.today() + timedelta(days=1)):
2429+
context.run()
2430+
2431+
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2432+
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2433+
# snowflake doesn't track rows for CTAS
2434+
assert actual_execution_stats["full_model"].total_rows_processed == (
2435+
None if ctx.mark.startswith("snowflake") else 3
2436+
)
24392437

24402438
# make and validate unmodified dev environment
24412439
no_change_plan: Plan = context.plan_builder(

0 commit comments

Comments (0)