diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 648594d0c0..adb4aa0d19 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -134,6 +134,10 @@ def _job_params(self) -> t.Dict[str, t.Any]: } if self._extra_config.get("maximum_bytes_billed"): params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed") + if self.correlation_id: + # BigQuery label keys must be lowercase + key = self.correlation_id.job_type.value.lower() + params["labels"] = {key: self.correlation_id.job_id} return params @property @@ -204,6 +208,11 @@ def _begin_session(self, properties: SessionProperties) -> None: "Invalid value for `session_properties.query_label`. Must be an array or tuple." ) + if self.correlation_id: + parsed_query_label.append( + (self.correlation_id.job_type.value.lower(), self.correlation_id.job_id) + ) + if parsed_query_label: query_label_str = ",".join([":".join(label) for label in parsed_query_label]) query = f'SET @@query_label = "{query_label_str}";SELECT 1;' diff --git a/tests/core/engine_adapter/integration/test_integration_bigquery.py b/tests/core/engine_adapter/integration/test_integration_bigquery.py index e974d79da2..c97c94d036 100644 --- a/tests/core/engine_adapter/integration/test_integration_bigquery.py +++ b/tests/core/engine_adapter/integration/test_integration_bigquery.py @@ -11,8 +11,9 @@ from sqlmesh.core.engine_adapter.shared import DataObject import sqlmesh.core.dialect as d from sqlmesh.core.model import SqlModel, load_sql_based_model -from sqlmesh.core.plan import Plan +from sqlmesh.core.plan import Plan, BuiltInPlanEvaluator from sqlmesh.core.table_diff import TableDiff +from sqlmesh.utils import CorrelationId from tests.core.engine_adapter.integration import TestContext from pytest import FixtureRequest from tests.core.engine_adapter.integration import ( @@ -447,3 +448,33 @@ def test_materialized_view_evaluation(ctx: TestContext, engine_adapter: BigQuery df = engine_adapter.fetchdf(f"SELECT * FROM {mview_name.sql(dialect=ctx.dialect)}") assert df["col"][0] == 2 + + +def test_correlation_id_in_job_labels(ctx: TestContext): + model_name = ctx.table("test") + + sqlmesh = ctx.create_context() + sqlmesh.upsert_model( + load_sql_based_model(d.parse(f"MODEL (name {model_name}, kind FULL); SELECT 1 AS col")) + ) + + # Create a plan evaluator and a plan to evaluate + plan_evaluator = BuiltInPlanEvaluator( + sqlmesh.state_sync, + sqlmesh.snapshot_evaluator, + sqlmesh.create_scheduler, + sqlmesh.default_catalog, + ) + plan: Plan = sqlmesh.plan_builder("prod", skip_tests=True).build() + + # Evaluate the plan and retrieve the plan evaluator's adapter + plan_evaluator.evaluate(plan.to_evaluatable()) + adapter = t.cast(BigQueryEngineAdapter, plan_evaluator.snapshot_evaluator.adapter) + + # Case 1: Ensure that the correlation id is set in the underlying adapter + assert adapter.correlation_id is not None + + # Case 2: Ensure that the correlation id is set in the job labels + labels = adapter._job_params.get("labels") + correlation_id = CorrelationId.from_plan_id(plan.plan_id) + assert labels == {correlation_id.job_type.value.lower(): correlation_id.job_id}