Skip to content

Commit 1bd8ae9

Browse files
committed
Refactor to fetch all jobs across all threads
1 parent 88a3501 commit 1bd8ae9

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

sqlmesh/core/engine_adapter/bigquery.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -188,22 +188,27 @@ def query_factory() -> Query:
188188
]
189189

190190
def close(self) -> t.Any:
191-
query_job = self._query_job
192-
if not query_job:
193-
return super().close()
194-
195-
# Cancel the last submitted query job if it's still pending, to avoid it becoming orphan (e.g., if interrupted)
196-
try:
197-
if not self._db_call(query_job.done):
198-
self._db_call(query_job.cancel)
199-
except Exception as ex:
200-
logger.debug(
201-
"Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
202-
query_job.project,
203-
query_job.location,
204-
query_job.job_id,
205-
str(ex),
206-
)
191+
# Cancel all pending query jobs across all threads
192+
all_query_jobs = self._connection_pool.get_all_attributes("query_job")
193+
for query_job in all_query_jobs:
194+
if query_job:
195+
try:
196+
if not self._db_call(query_job.done):
197+
self._db_call(query_job.cancel)
198+
logger.debug(
199+
"Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s",
200+
query_job.project,
201+
query_job.location,
202+
query_job.job_id,
203+
)
204+
except Exception as ex:
205+
logger.debug(
206+
"Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s",
207+
query_job.project,
208+
query_job.location,
209+
query_job.job_id,
210+
str(ex),
211+
)
207212

208213
return super().close()
209214

sqlmesh/utils/connection_pool.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ def set_attribute(self, key: str, value: t.Any) -> None:
4848
value: Attribute value.
4949
"""
5050

51+
@abc.abstractmethod
52+
def get_all_attributes(self, key: str) -> t.List[t.Any]:
53+
"""Returns all attributes with the given key across all connections/threads.
54+
55+
Args:
56+
key: Attribute key.
57+
58+
Returns:
59+
List of attribute values from all connections/threads.
60+
"""
61+
5162
@abc.abstractmethod
5263
def begin(self) -> None:
5364
"""Starts a new transaction."""
@@ -142,6 +153,14 @@ def set_attribute(self, key: str, value: t.Any) -> None:
142153
thread_id = get_ident()
143154
self._thread_attributes[thread_id][key] = value
144155

156+
def get_all_attributes(self, key: str) -> t.List[t.Any]:
157+
"""Returns all attributes with the given key across all threads."""
158+
return [
159+
thread_attrs[key]
160+
for thread_attrs in self._thread_attributes.values()
161+
if key in thread_attrs
162+
]
163+
145164
def begin(self) -> None:
146165
self._do_begin()
147166
with self._thread_transactions_lock:
@@ -282,6 +301,11 @@ def get_attribute(self, key: str) -> t.Optional[t.Any]:
282301
def set_attribute(self, key: str, value: t.Any) -> None:
283302
self._attributes[key] = value
284303

304+
def get_all_attributes(self, key: str) -> t.List[t.Any]:
305+
"""Returns all attributes with the given key (single-threaded pool has at most one)."""
306+
value = self._attributes.get(key)
307+
return [value] if value is not None else []
308+
285309
def begin(self) -> None:
286310
self._do_begin()
287311
self._is_transaction_active = True

0 commit comments

Comments
 (0)