Skip to content

Commit 67377d8

Browse files
committed
Remove time travel test for cloud engines, handle pyspark DFs in dbx
1 parent ecab844 commit 67377d8

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SourceQuery,
1515
)
1616
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
17+
from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor
1718
from sqlmesh.core.node import IntervalUnit
1819
from sqlmesh.core.schema_diff import SchemaDiffer
1920
from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker
@@ -380,23 +381,30 @@ def _record_execution_stats(
380381
except:
381382
return
382383

383-
history = self.cursor.fetchall_arrow()
384-
if history.num_rows:
385-
history_df = history.to_pandas()
386-
write_df = history_df[history_df["operation"] == "WRITE"]
387-
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
388-
if not write_df.empty:
389-
metrics = write_df["operationMetrics"][0]
390-
if metrics:
391-
rowcount = None
392-
rowcount_str = [
393-
metric[1] for metric in metrics if metric[0] == "numOutputRows"
394-
]
395-
if rowcount_str:
396-
try:
397-
rowcount = int(rowcount_str[0])
398-
except (TypeError, ValueError):
399-
pass
384+
history = (
385+
self.cursor.fetchdf()
386+
if isinstance(self.cursor, SparkSessionCursor)
387+
else self.cursor.fetchall_arrow()
388+
)
389+
if history is not None:
390+
history_df = (
391+
history.to_pandas() if not isinstance(history, pd.DataFrame) else history # type: ignore
392+
)
393+
if not history_df.empty:
394+
write_df = history_df[history_df["operation"] == "WRITE"]
395+
write_df = write_df[write_df["timestamp"] == write_df["timestamp"].max()]
396+
if not write_df.empty:
397+
metrics = write_df["operationMetrics"][0]
398+
if metrics:
399+
rowcount = None
400+
rowcount_str = [
401+
metric[1] for metric in metrics if metric[0] == "numOutputRows"
402+
]
403+
if rowcount_str:
404+
try:
405+
rowcount = int(rowcount_str[0])
406+
except (TypeError, ValueError):
407+
pass
400408

401409
bytes_processed = None
402410
bytes_str = [

tests/core/engine_adapter/integration/test_integration.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,23 +2419,21 @@ def capture_execution_stats(
24192419
assert actual_execution_stats["full_model"].total_bytes_processed is not None
24202420

24212421
# run that loads 0 rows in incremental model
2422-
actual_execution_stats = {}
2423-
with patch.object(
2424-
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2425-
):
2426-
with time_machine.travel(date.today() + timedelta(days=1)):
2427-
context.run()
2428-
2429-
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2430-
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2431-
# snowflake doesn't track rows for CTAS
2432-
assert actual_execution_stats["full_model"].total_rows_processed == (
2433-
None if ctx.mark.startswith("snowflake") else 3
2434-
)
2435-
2436-
if ctx.mark.startswith("bigquery") or ctx.mark.startswith("databricks"):
2437-
assert actual_execution_stats["incremental_model"].total_bytes_processed is not None
2438-
assert actual_execution_stats["full_model"].total_bytes_processed is not None
2422+
# - some cloud DBs error because time travel messes up token expiration
2423+
if not ctx.is_remote:
2424+
actual_execution_stats = {}
2425+
with patch.object(
2426+
context.console, "update_snapshot_evaluation_progress", capture_execution_stats
2427+
):
2428+
with time_machine.travel(date.today() + timedelta(days=1)):
2429+
context.run()
2430+
2431+
if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING:
2432+
assert actual_execution_stats["incremental_model"].total_rows_processed == 0
2433+
# snowflake doesn't track rows for CTAS
2434+
assert actual_execution_stats["full_model"].total_rows_processed == (
2435+
None if ctx.mark.startswith("snowflake") else 3
2436+
)
24392437

24402438
# make and validate unmodified dev environment
24412439
no_change_plan: Plan = context.plan_builder(

0 commit comments

Comments (0)