|
33 | 33 | from google.cloud import bigquery |
34 | 34 | from google.cloud.bigquery import StandardSqlDataType |
35 | 35 | from google.cloud.bigquery.client import Client as BigQueryClient |
| 36 | + from google.cloud.bigquery.job import QueryJob |
36 | 37 | from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult |
37 | 38 | from google.cloud.bigquery.table import Table as BigQueryTable |
38 | 39 |
|
@@ -187,21 +188,23 @@ def query_factory() -> Query: |
187 | 188 | ] |
188 | 189 |
|
189 | 190 | def close(self) -> t.Any: |
190 | | - # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts |
191 | | - for query_job in self._query_jobs: |
192 | | - try: |
193 | | - if not self._db_call(query_job.done): |
194 | | - self._db_call(query_job.cancel) |
195 | | - except Exception as ex: |
196 | | - logger.debug( |
197 | | - "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
198 | | - self._query_job.project, |
199 | | - self._query_job.location, |
200 | | - self._query_job.job_id, |
201 | | - str(ex), |
202 | | - ) |
| 191 | + query_job = self._query_job |
| 192 | + if not query_job: |
| 193 | + return super().close() |
| 194 | + |
| 195 | + # Cancel the last submitted query job if it's still pending, to avoid it becoming orphan (e.g., if interrupted) |
| 196 | + try: |
| 197 | + if not self._db_call(query_job.done): |
| 198 | + self._db_call(query_job.cancel) |
| 199 | + except Exception as ex: |
| 200 | + logger.debug( |
| 201 | + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
| 202 | + query_job.project, |
| 203 | + query_job.location, |
| 204 | + query_job.job_id, |
| 205 | + str(ex), |
| 206 | + ) |
203 | 207 |
|
204 | | - self._query_jobs.clear() |
205 | 208 | return super().close() |
206 | 209 |
|
207 | 210 | def _begin_session(self, properties: SessionProperties) -> None: |
@@ -336,7 +339,10 @@ def create_mapping_schema( |
336 | 339 | if len(table.parts) == 3 and "." in table.name: |
337 | 340 | # The client's `get_table` method can't handle paths with >3 identifiers |
338 | 341 | self.execute(exp.select("*").from_(table).limit(0)) |
339 | | - query_results = self._query_job._query_results |
| 342 | + query_job = self._query_job |
| 343 | + assert query_job is not None |
| 344 | + |
| 345 | + query_results = query_job._query_results |
340 | 346 | columns = create_mapping_schema(query_results.schema) |
341 | 347 | else: |
342 | 348 | bq_table = self._get_table(table) |
@@ -723,7 +729,9 @@ def _fetch_native_df( |
723 | 729 | self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False |
724 | 730 | ) -> DF: |
725 | 731 | self.execute(query, quote_identifiers=quote_identifiers) |
726 | | - return self._query_job.to_dataframe() |
| 732 | + query_job = self._query_job |
| 733 | + assert query_job is not None |
| 734 | + return query_job.to_dataframe() |
727 | 735 |
|
728 | 736 | def _create_column_comments( |
729 | 737 | self, |
@@ -1027,24 +1035,23 @@ def _execute( |
1027 | 1035 | job_config=job_config, |
1028 | 1036 | timeout=self._extra_config.get("job_creation_timeout_seconds"), |
1029 | 1037 | ) |
1030 | | - self._query_jobs.add(self._query_job) |
| 1038 | + query_job = self._query_job |
| 1039 | + assert query_job is not None |
1031 | 1040 |
|
1032 | 1041 | logger.debug( |
1033 | 1042 | "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", |
1034 | | - self._query_job.project, |
1035 | | - self._query_job.location, |
1036 | | - self._query_job.job_id, |
| 1043 | + query_job.project, |
| 1044 | + query_job.location, |
| 1045 | + query_job.job_id, |
1037 | 1046 | ) |
1038 | 1047 |
|
1039 | 1048 | results = self._db_call( |
1040 | | - self._query_job.result, |
| 1049 | + query_job.result, |
1041 | 1050 | timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore |
1042 | 1051 | ) |
1043 | 1052 |
|
1044 | | - self._query_jobs.remove(self._query_job) |
1045 | | - |
1046 | 1053 | self._query_data = iter(results) if results.total_rows else iter([]) |
1047 | | - query_results = self._query_job._query_results |
| 1054 | + query_results = query_job._query_results |
1048 | 1055 | self.cursor._set_rowcount(query_results) |
1049 | 1056 | self.cursor._set_description(query_results.schema) |
1050 | 1057 |
|
@@ -1208,32 +1215,23 @@ def _query_data(self) -> t.Any: |
1208 | 1215 |
|
1209 | 1216 | @_query_data.setter |
1210 | 1217 | def _query_data(self, value: t.Any) -> None: |
1211 | | - return self._connection_pool.set_attribute("query_data", value) |
1212 | | - |
1213 | | - @property |
1214 | | - def _query_jobs(self) -> t.Any: |
1215 | | - query_jobs = self._connection_pool.get_attribute("query_jobs") |
1216 | | - if not isinstance(query_jobs, set): |
1217 | | - query_jobs = set() |
1218 | | - self._connection_pool.set_attribute("query_jobs", query_jobs) |
1219 | | - |
1220 | | - return query_jobs |
| 1218 | + self._connection_pool.set_attribute("query_data", value) |
1221 | 1219 |
|
1222 | 1220 | @property |
1223 | | - def _query_job(self) -> t.Any: |
| 1221 | + def _query_job(self) -> t.Optional[QueryJob]: |
1224 | 1222 | return self._connection_pool.get_attribute("query_job") |
1225 | 1223 |
|
1226 | 1224 | @_query_job.setter |
1227 | 1225 | def _query_job(self, value: t.Any) -> None: |
1228 | | - return self._connection_pool.set_attribute("query_job", value) |
| 1226 | + self._connection_pool.set_attribute("query_job", value) |
1229 | 1227 |
|
1230 | 1228 | @property |
1231 | 1229 | def _session_id(self) -> t.Any: |
1232 | 1230 | return self._connection_pool.get_attribute("session_id") |
1233 | 1231 |
|
1234 | 1232 | @_session_id.setter |
1235 | 1233 | def _session_id(self, value: t.Any) -> None: |
1236 | | - return self._connection_pool.set_attribute("session_id", value) |
| 1234 | + self._connection_pool.set_attribute("session_id", value) |
1237 | 1235 |
|
1238 | 1236 |
|
1239 | 1237 | class _ErrorCounter: |
|
0 commit comments