55from contextlib import contextmanager
66from threading import local
77from dataclasses import dataclass , field
8+ from sqlmesh .utils .errors import SQLMeshError
89
910
1011@dataclass
@@ -27,7 +28,6 @@ class QueryExecutionContext:
2728 queries_executed : t .List [t .Tuple [str , t .Optional [int ], float ]] = field (default_factory = list )
2829
2930 def add_execution (self , sql : str , row_count : t .Optional [int ]) -> None :
30- """Record a single query execution."""
3131 if row_count is not None and row_count >= 0 :
3232 self .total_rows_processed += row_count
3333 self .query_count += 1
@@ -46,28 +46,41 @@ def get_execution_stats(self) -> t.Dict[str, t.Any]:
4646
4747class QueryExecutionTracker :
4848 """
49- Thread-local context manager for snapshot evaluation execution statistics, such as
49+ Thread-local context manager for snapshot execution statistics, such as
5050 rows processed.
5151 """
5252
5353 _thread_local = local ()
54+ _contexts : t .Dict [str , QueryExecutionContext ] = {}
5455
5556 @classmethod
56- def get_execution_context (cls ) -> t .Optional [QueryExecutionContext ]:
57- return getattr ( cls ._thread_local , "context" , None )
57+ def get_execution_context (cls , snapshot_id_batch : str ) -> t .Optional [QueryExecutionContext ]:
58+ return cls ._contexts . get ( snapshot_id_batch )
5859
5960 @classmethod
6061 def is_tracking (cls ) -> bool :
61- return cls .get_execution_context ( ) is not None
62+ return getattr ( cls ._thread_local , "context" , None ) is not None
6263
6364 @classmethod
6465 @contextmanager
65- def track_execution (cls , snapshot_name_batch : str ) -> t .Iterator [QueryExecutionContext ]:
66+ def track_execution (
67+ cls , snapshot_id_batch : str , condition : bool = True
68+ ) -> t .Iterator [t .Optional [QueryExecutionContext ]]:
6669 """
67- Context manager for tracking snapshot evaluation execution statistics.
70+ Context manager for tracking snapshot execution statistics.
6871 """
69- context = QueryExecutionContext (id = snapshot_name_batch )
72+ if not condition :
73+ yield None
74+ return
75+
76+ if snapshot_id_batch in cls ._contexts :
77+ raise SQLMeshError (
78+ f"Snapshot ID batch { snapshot_id_batch } execution has already been tracked. Each snapshot should only be tracked once."
79+ )
80+
81+ context = QueryExecutionContext (id = snapshot_id_batch )
7082 cls ._thread_local .context = context
83+ cls ._contexts [snapshot_id_batch ] = context
7184 try :
7285 yield context
7386 finally :
@@ -76,67 +89,12 @@ def track_execution(cls, snapshot_name_batch: str) -> t.Iterator[QueryExecutionC
7689
7790 @classmethod
7891 def record_execution (cls , sql : str , row_count : t .Optional [int ]) -> None :
79- context = cls .get_execution_context ( )
92+ context = getattr ( cls ._thread_local , "context" , None )
8093 if context is not None :
8194 context .add_execution (sql , row_count )
8295
8396 @classmethod
84- def get_execution_stats (cls ) -> t .Optional [t .Dict [str , t .Any ]]:
85- context = cls .get_execution_context ()
86- return context .get_execution_stats () if context else None
87-
88-
89- class SeedExecutionTracker :
90- _seed_contexts : t .Dict [str , QueryExecutionContext ] = {}
91- _thread_local = local ()
92-
93- @classmethod
94- @contextmanager
95- def track_execution (cls , model_name : str ) -> t .Iterator [QueryExecutionContext ]:
96- """
97- Context manager for tracking seed creation execution statistics.
98- """
99- context = QueryExecutionContext (id = model_name )
100- cls ._seed_contexts [model_name ] = context
101- cls ._thread_local .seed_id = model_name
102-
103- try :
104- yield context
105- finally :
106- if hasattr (cls ._thread_local , "seed_id" ):
107- delattr (cls ._thread_local , "seed_id" )
108-
109- @classmethod
110- def get_and_clear_seed_stats (cls , model_name : str ) -> t .Optional [t .Dict [str , t .Any ]]:
111- context = cls ._seed_contexts .pop (model_name , None )
97+ def get_execution_stats (cls , snapshot_id_batch : str ) -> t .Optional [t .Dict [str , t .Any ]]:
98+ context = cls .get_execution_context (snapshot_id_batch )
99+ cls ._contexts .pop (snapshot_id_batch , None )
112100 return context .get_execution_stats () if context else None
113-
114- @classmethod
115- def clear_all_seed_stats (cls ) -> None :
116- """Clear all remaining seed stats. Used for cleanup after evaluation completes."""
117- cls ._seed_contexts .clear ()
118-
119- @classmethod
120- def is_tracking (cls ) -> bool :
121- return hasattr (cls ._thread_local , "seed_id" )
122-
123- @classmethod
124- def record_execution (cls , sql : str , row_count : t .Optional [int ]) -> None :
125- seed_id = getattr (cls ._thread_local , "seed_id" , None )
126- if seed_id :
127- context = cls ._seed_contexts .get (seed_id )
128- if context is not None :
129- context .add_execution (sql , row_count )
130-
131-
132- def record_execution (sql : str , row_count : t .Optional [int ]) -> None :
133- """
134- Record execution statistics for a single SQL statement.
135-
136- Automatically infers which tracker is active based on the current thread.
137- """
138- if SeedExecutionTracker .is_tracking ():
139- SeedExecutionTracker .record_execution (sql , row_count )
140- return
141- if QueryExecutionTracker .is_tracking ():
142- QueryExecutionTracker .record_execution (sql , row_count )
0 commit comments