Skip to content

Commit 86ba6d6

Browse files
committed
Move all tracking into snapshot evaluator, remove seed tracker class
1 parent 03900c2 commit 86ba6d6

File tree

6 files changed: 161 additions and 240 deletions.

sqlmesh/core/engine_adapter/base.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
)
4141
from sqlmesh.core.model.kind import TimeColumn
4242
from sqlmesh.core.schema_diff import SchemaDiffer
43-
from sqlmesh.core.execution_tracker import record_execution as track_execution_record
43+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
4444
from sqlmesh.utils import CorrelationId, columns_to_types_all_known, random_id
4545
from sqlmesh.utils.connection_pool import ConnectionPool, create_connection_pool
4646
from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column
@@ -2283,7 +2283,11 @@ def _log_sql(
22832283
def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) -> None:
22842284
self.cursor.execute(sql, **kwargs)
22852285

2286-
if track_row_count and self.SUPPORTS_QUERY_EXECUTION_TRACKING:
2286+
if (
2287+
self.SUPPORTS_QUERY_EXECUTION_TRACKING
2288+
and track_row_count
2289+
and QueryExecutionTracker.is_tracking()
2290+
):
22872291
rowcount_raw = getattr(self.cursor, "rowcount", None)
22882292
rowcount = None
22892293
if rowcount_raw is not None:
@@ -2292,7 +2296,7 @@ def _execute(self, sql: str, track_row_count: bool = False, **kwargs: t.Any) ->
22922296
except (TypeError, ValueError):
22932297
pass
22942298

2295-
track_execution_record(sql, rowcount)
2299+
QueryExecutionTracker.record_execution(sql, rowcount)
22962300

22972301
@contextlib.contextmanager
22982302
def temp_table(

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
SourceQuery,
2121
set_catalog,
2222
)
23-
from sqlmesh.core.execution_tracker import record_execution as track_execution_record
23+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
2424
from sqlmesh.core.node import IntervalUnit
2525
from sqlmesh.core.schema_diff import SchemaDiffer
2626
from sqlmesh.utils import optional_import
@@ -1091,7 +1091,7 @@ def _execute(
10911091
elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]:
10921092
num_rows = query_job.num_dml_affected_rows
10931093

1094-
track_execution_record(sql, num_rows)
1094+
QueryExecutionTracker.record_execution(sql, num_rows)
10951095

10961096
def _get_data_objects(
10971097
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None

sqlmesh/core/execution_tracker.py

Lines changed: 25 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from contextlib import contextmanager
66
from threading import local
77
from dataclasses import dataclass, field
8+
from sqlmesh.utils.errors import SQLMeshError
89

910

1011
@dataclass
@@ -27,7 +28,6 @@ class QueryExecutionContext:
2728
queries_executed: t.List[t.Tuple[str, t.Optional[int], float]] = field(default_factory=list)
2829

2930
def add_execution(self, sql: str, row_count: t.Optional[int]) -> None:
30-
"""Record a single query execution."""
3131
if row_count is not None and row_count >= 0:
3232
self.total_rows_processed += row_count
3333
self.query_count += 1
@@ -46,28 +46,41 @@ def get_execution_stats(self) -> t.Dict[str, t.Any]:
4646

4747
class QueryExecutionTracker:
4848
"""
49-
Thread-local context manager for snapshot evaluation execution statistics, such as
49+
Thread-local context manager for snapshot execution statistics, such as
5050
rows processed.
5151
"""
5252

5353
_thread_local = local()
54+
_contexts: t.Dict[str, QueryExecutionContext] = {}
5455

5556
@classmethod
56-
def get_execution_context(cls) -> t.Optional[QueryExecutionContext]:
57-
return getattr(cls._thread_local, "context", None)
57+
def get_execution_context(cls, snapshot_id_batch: str) -> t.Optional[QueryExecutionContext]:
58+
return cls._contexts.get(snapshot_id_batch)
5859

5960
@classmethod
6061
def is_tracking(cls) -> bool:
61-
return cls.get_execution_context() is not None
62+
return getattr(cls._thread_local, "context", None) is not None
6263

6364
@classmethod
6465
@contextmanager
65-
def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionContext]:
66+
def track_execution(
67+
cls, snapshot_id_batch: str, condition: bool = True
68+
) -> t.Iterator[t.Optional[QueryExecutionContext]]:
6669
"""
67-
Context manager for tracking snapshot evaluation execution statistics.
70+
Context manager for tracking snapshot execution statistics.
6871
"""
69-
context = QueryExecutionContext(id=snapshot_name_batch)
72+
if not condition:
73+
yield None
74+
return
75+
76+
if snapshot_id_batch in cls._contexts:
77+
raise SQLMeshError(
78+
f"Snapshot ID batch {snapshot_id_batch} execution has already been tracked. Each snapshot should only be tracked once."
79+
)
80+
81+
context = QueryExecutionContext(id=snapshot_id_batch)
7082
cls._thread_local.context = context
83+
cls._contexts[snapshot_id_batch] = context
7184
try:
7285
yield context
7386
finally:
@@ -76,67 +89,12 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC
7689

7790
@classmethod
7891
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
79-
context = cls.get_execution_context()
92+
context = getattr(cls._thread_local, "context", None)
8093
if context is not None:
8194
context.add_execution(sql, row_count)
8295

8396
@classmethod
84-
def get_execution_stats(cls) -> t.Optional[t.Dict[str, t.Any]]:
85-
context = cls.get_execution_context()
86-
return context.get_execution_stats() if context else None
87-
88-
89-
class SeedExecutionTracker:
90-
_seed_contexts: t.Dict[str, QueryExecutionContext] = {}
91-
_thread_local = local()
92-
93-
@classmethod
94-
@contextmanager
95-
def track_execution(cls, model_name: str) -> t.Iterator[QueryExecutionContext]:
96-
"""
97-
Context manager for tracking seed creation execution statistics.
98-
"""
99-
context = QueryExecutionContext(id=model_name)
100-
cls._seed_contexts[model_name] = context
101-
cls._thread_local.seed_id = model_name
102-
103-
try:
104-
yield context
105-
finally:
106-
if hasattr(cls._thread_local, "seed_id"):
107-
delattr(cls._thread_local, "seed_id")
108-
109-
@classmethod
110-
def get_and_clear_seed_stats(cls, model_name: str) -> t.Optional[t.Dict[str, t.Any]]:
111-
context = cls._seed_contexts.pop(model_name, None)
97+
def get_execution_stats(cls, snapshot_id_batch: str) -> t.Optional[t.Dict[str, t.Any]]:
98+
context = cls.get_execution_context(snapshot_id_batch)
99+
cls._contexts.pop(snapshot_id_batch, None)
112100
return context.get_execution_stats() if context else None
113-
114-
@classmethod
115-
def clear_all_seed_stats(cls) -> None:
116-
"""Clear all remaining seed stats. Used for cleanup after evaluation completes."""
117-
cls._seed_contexts.clear()
118-
119-
@classmethod
120-
def is_tracking(cls) -> bool:
121-
return hasattr(cls._thread_local, "seed_id")
122-
123-
@classmethod
124-
def record_execution(cls, sql: str, row_count: t.Optional[int]) -> None:
125-
seed_id = getattr(cls._thread_local, "seed_id", None)
126-
if seed_id:
127-
context = cls._seed_contexts.get(seed_id)
128-
if context is not None:
129-
context.add_execution(sql, row_count)
130-
131-
132-
def record_execution(sql: str, row_count: t.Optional[int]) -> None:
133-
"""
134-
Record execution statistics for a single SQL statement.
135-
136-
Automatically infers which tracker is active based on the current thread.
137-
"""
138-
if SeedExecutionTracker.is_tracking():
139-
SeedExecutionTracker.record_execution(sql, row_count)
140-
return
141-
if QueryExecutionTracker.is_tracking():
142-
QueryExecutionTracker.record_execution(sql, row_count)

sqlmesh/core/scheduler.py

Lines changed: 52 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from sqlmesh.core import constants as c
88
from sqlmesh.core.console import Console, get_console
99
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
10-
from sqlmesh.core.execution_tracker import QueryExecutionTracker, SeedExecutionTracker
10+
from sqlmesh.core.execution_tracker import QueryExecutionTracker
1111
from sqlmesh.core.macros import RuntimeStage
1212
from sqlmesh.core.model.definition import AuditResult
1313
from sqlmesh.core.node import IntervalUnit
@@ -427,69 +427,59 @@ def evaluate_node(node: SchedulingUnit) -> None:
427427
return
428428
snapshot = self.snapshots_by_name[snapshot_name]
429429

430-
with QueryExecutionTracker.track_execution(
431-
f"{snapshot.name}_{batch_idx}"
432-
) as execution_context:
433-
self.console.start_snapshot_evaluation_progress(snapshot)
434-
435-
execution_start_ts = now_timestamp()
436-
evaluation_duration_ms: t.Optional[int] = None
437-
438-
audit_results: t.List[AuditResult] = []
439-
try:
440-
assert execution_time # mypy
441-
assert deployability_index # mypy
442-
443-
if audit_only:
444-
audit_results = self._audit_snapshot(
445-
snapshot=snapshot,
446-
environment_naming_info=environment_naming_info,
447-
deployability_index=deployability_index,
448-
snapshots=self.snapshots_by_name,
449-
start=start,
450-
end=end,
451-
execution_time=execution_time,
452-
)
453-
else:
454-
audit_results = self.evaluate(
455-
snapshot=snapshot,
456-
environment_naming_info=environment_naming_info,
457-
start=start,
458-
end=end,
459-
execution_time=execution_time,
460-
deployability_index=deployability_index,
461-
batch_index=batch_idx,
462-
)
463-
464-
evaluation_duration_ms = now_timestamp() - execution_start_ts
465-
finally:
466-
num_audits = len(audit_results)
467-
num_audits_failed = sum(1 for result in audit_results if result.count)
468-
469-
rows_processed = None
470-
if snapshot.is_seed:
471-
# seed stats are tracked in SeedStrategy.create by model name, not snapshot name
472-
seed_stats = SeedExecutionTracker.get_and_clear_seed_stats(
473-
snapshot.model.name
474-
)
475-
rows_processed = (
476-
seed_stats.get("total_rows_processed") if seed_stats else None
477-
)
478-
else:
479-
rows_processed = (
480-
execution_context.total_rows_processed if execution_context else None
481-
)
482-
483-
self.console.update_snapshot_evaluation_progress(
484-
snapshot,
485-
batched_intervals[snapshot][batch_idx],
486-
batch_idx,
487-
evaluation_duration_ms,
488-
num_audits - num_audits_failed,
489-
num_audits_failed,
490-
rows_processed=rows_processed,
430+
self.console.start_snapshot_evaluation_progress(snapshot)
431+
432+
execution_start_ts = now_timestamp()
433+
evaluation_duration_ms: t.Optional[int] = None
434+
435+
audit_results: t.List[AuditResult] = []
436+
try:
437+
assert execution_time # mypy
438+
assert deployability_index # mypy
439+
440+
if audit_only:
441+
audit_results = self._audit_snapshot(
442+
snapshot=snapshot,
443+
environment_naming_info=environment_naming_info,
444+
deployability_index=deployability_index,
445+
snapshots=self.snapshots_by_name,
446+
start=start,
447+
end=end,
448+
execution_time=execution_time,
449+
)
450+
else:
451+
audit_results = self.evaluate(
452+
snapshot=snapshot,
453+
environment_naming_info=environment_naming_info,
454+
start=start,
455+
end=end,
456+
execution_time=execution_time,
457+
deployability_index=deployability_index,
458+
batch_index=batch_idx,
491459
)
492460

461+
evaluation_duration_ms = now_timestamp() - execution_start_ts
462+
finally:
463+
num_audits = len(audit_results)
464+
num_audits_failed = sum(1 for result in audit_results if result.count)
465+
466+
execution_stats = QueryExecutionTracker.get_execution_stats(
467+
f"{snapshot.snapshot_id}_{batch_idx}"
468+
)
469+
rows_processed = (
470+
execution_stats["total_rows_processed"] if execution_stats else None
471+
)
472+
473+
self.console.update_snapshot_evaluation_progress(
474+
snapshot,
475+
batched_intervals[snapshot][batch_idx],
476+
batch_idx,
477+
evaluation_duration_ms,
478+
num_audits - num_audits_failed,
479+
num_audits_failed,
480+
rows_processed=rows_processed,
481+
)
482+
493483
try:
494484
with self.snapshot_evaluator.concurrent_context():
495485
errors, skipped_intervals = concurrent_apply_to_dag(
@@ -529,9 +519,6 @@ def evaluate_node(node: SchedulingUnit) -> None:
529519

530520
self.state_sync.recycle()
531521

532-
# Clean up any remaining seed execution stats
533-
SeedExecutionTracker.clear_all_seed_stats()
534-
535522
def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]:
536523
"""Builds a DAG of snapshot intervals to be evaluated.
537524

0 commit comments

Comments
 (0)