|
33 | 33 | from google.cloud import bigquery |
34 | 34 | from google.cloud.bigquery import StandardSqlDataType |
35 | 35 | from google.cloud.bigquery.client import Client as BigQueryClient |
| 36 | + from google.cloud.bigquery.job import QueryJob |
36 | 37 | from google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult |
37 | 38 | from google.cloud.bigquery.table import Table as BigQueryTable |
38 | 39 |
|
@@ -187,21 +188,23 @@ def query_factory() -> Query: |
187 | 188 | ] |
188 | 189 |
|
189 | 190 | def close(self) -> t.Any: |
190 | | - # Cancel all pending query jobs to avoid them becoming orphan, e.g., due to interrupts |
191 | | - for query_job in self._query_jobs: |
192 | | - try: |
193 | | - if not self._db_call(query_job.done): |
194 | | - self._db_call(query_job.cancel) |
195 | | - except Exception as ex: |
196 | | - logger.debug( |
197 | | - "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
198 | | - self._query_job.project, |
199 | | - self._query_job.location, |
200 | | - self._query_job.job_id, |
201 | | - str(ex), |
202 | | - ) |
| 191 | + query_job = self._query_job |
| 192 | + if not query_job: |
| 193 | + return super().close() |
| 194 | + |
| 195 | + # Cancel the last submitted query job if it's still pending, to avoid it becoming orphaned (e.g., if interrupted) |
| 196 | + try: |
| 197 | + if not self._db_call(query_job.done): |
| 198 | + self._db_call(query_job.cancel) |
| 199 | + except Exception as ex: |
| 200 | + logger.debug( |
| 201 | + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", |
| 202 | + query_job.project, |
| 203 | + query_job.location, |
| 204 | + query_job.job_id, |
| 205 | + str(ex), |
| 206 | + ) |
203 | 207 |
|
204 | | - self._query_jobs.clear() |
205 | 208 | return super().close() |
206 | 209 |
|
207 | 210 | def _begin_session(self, properties: SessionProperties) -> None: |
@@ -336,7 +339,10 @@ def create_mapping_schema( |
336 | 339 | if len(table.parts) == 3 and "." in table.name: |
337 | 340 | # The client's `get_table` method can't handle paths with >3 identifiers |
338 | 341 | self.execute(exp.select("*").from_(table).limit(0)) |
339 | | - query_results = self._query_job._query_results |
| 342 | + query_job = self._query_job |
| 343 | + assert query_job is not None |
| 344 | + |
| 345 | + query_results = query_job._query_results |
340 | 346 | columns = create_mapping_schema(query_results.schema) |
341 | 347 | else: |
342 | 348 | bq_table = self._get_table(table) |
@@ -735,7 +741,9 @@ def _fetch_native_df( |
735 | 741 | self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False |
736 | 742 | ) -> DF: |
737 | 743 | self.execute(query, quote_identifiers=quote_identifiers) |
738 | | - return self._query_job.to_dataframe() |
| 744 | + query_job = self._query_job |
| 745 | + assert query_job is not None |
| 746 | + return query_job.to_dataframe() |
739 | 747 |
|
740 | 748 | def _create_column_comments( |
741 | 749 | self, |
@@ -1039,24 +1047,23 @@ def _execute( |
1039 | 1047 | job_config=job_config, |
1040 | 1048 | timeout=self._extra_config.get("job_creation_timeout_seconds"), |
1041 | 1049 | ) |
1042 | | - self._query_jobs.add(self._query_job) |
| 1050 | + query_job = self._query_job |
| 1051 | + assert query_job is not None |
1043 | 1052 |
|
1044 | 1053 | logger.debug( |
1045 | 1054 | "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", |
1046 | | - self._query_job.project, |
1047 | | - self._query_job.location, |
1048 | | - self._query_job.job_id, |
| 1055 | + query_job.project, |
| 1056 | + query_job.location, |
| 1057 | + query_job.job_id, |
1049 | 1058 | ) |
1050 | 1059 |
|
1051 | 1060 | results = self._db_call( |
1052 | | - self._query_job.result, |
| 1061 | + query_job.result, |
1053 | 1062 | timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore |
1054 | 1063 | ) |
1055 | 1064 |
|
1056 | | - self._query_jobs.remove(self._query_job) |
1057 | | - |
1058 | 1065 | self._query_data = iter(results) if results.total_rows else iter([]) |
1059 | | - query_results = self._query_job._query_results |
| 1066 | + query_results = query_job._query_results |
1060 | 1067 | self.cursor._set_rowcount(query_results) |
1061 | 1068 | self.cursor._set_description(query_results.schema) |
1062 | 1069 |
|
@@ -1220,32 +1227,23 @@ def _query_data(self) -> t.Any: |
1220 | 1227 |
|
1221 | 1228 | @_query_data.setter |
1222 | 1229 | def _query_data(self, value: t.Any) -> None: |
1223 | | - return self._connection_pool.set_attribute("query_data", value) |
1224 | | - |
1225 | | - @property |
1226 | | - def _query_jobs(self) -> t.Any: |
1227 | | - query_jobs = self._connection_pool.get_attribute("query_jobs") |
1228 | | - if not isinstance(query_jobs, set): |
1229 | | - query_jobs = set() |
1230 | | - self._connection_pool.set_attribute("query_jobs", query_jobs) |
1231 | | - |
1232 | | - return query_jobs |
| 1230 | + self._connection_pool.set_attribute("query_data", value) |
1233 | 1231 |
|
1234 | 1232 | @property |
1235 | | - def _query_job(self) -> t.Any: |
| 1233 | + def _query_job(self) -> t.Optional[QueryJob]: |
1236 | 1234 | return self._connection_pool.get_attribute("query_job") |
1237 | 1235 |
|
1238 | 1236 | @_query_job.setter |
1239 | 1237 | def _query_job(self, value: t.Any) -> None: |
1240 | | - return self._connection_pool.set_attribute("query_job", value) |
| 1238 | + self._connection_pool.set_attribute("query_job", value) |
1241 | 1239 |
|
1242 | 1240 | @property |
1243 | 1241 | def _session_id(self) -> t.Any: |
1244 | 1242 | return self._connection_pool.get_attribute("session_id") |
1245 | 1243 |
|
1246 | 1244 | @_session_id.setter |
1247 | 1245 | def _session_id(self, value: t.Any) -> None: |
1248 | | - return self._connection_pool.set_attribute("session_id", value) |
| 1246 | + self._connection_pool.set_attribute("session_id", value) |
1249 | 1247 |
|
1250 | 1248 |
|
1251 | 1249 | class _ErrorCounter: |
|
0 commit comments