From d184b2a3ef5e3bbe27f0e8bcc2b41e427cd3d6db Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 29 Aug 2025 08:13:42 -0700 Subject: [PATCH] Fix: Regression in WAP support --- sqlmesh/core/engine_adapter/spark.py | 9 + sqlmesh/core/snapshot/evaluator.py | 39 ++-- tests/core/engine_adapter/test_spark.py | 16 ++ tests/core/test_snapshot_evaluator.py | 254 +++++++++++++++++++++++- 4 files changed, 298 insertions(+), 20 deletions(-) diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 412e01f5bb..7d6a4d969b 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -402,6 +402,15 @@ def get_current_database(self) -> str: return self.spark.catalog.currentDatabase() return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore + def get_data_object(self, target_name: TableName) -> t.Optional[DataObject]: + target_table = exp.to_table(target_name) + if isinstance(target_table.this, exp.Dot) and target_table.this.expression.name.startswith( + f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}" + ): + # Exclude the branch name + target_table.set("this", target_table.this.this) + return super().get_data_object(target_table) + def create_state_table( self, table_name: str, diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 87a6d15c42..c53c0a88db 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -651,7 +651,7 @@ def audit( snapshot.snapshot_id, wap_id, ) - self._wap_publish_snapshot(snapshot, wap_id, deployability_index) + self.wap_publish_snapshot(snapshot, wap_id, deployability_index) return results @@ -806,8 +806,10 @@ def _evaluate_snapshot( } wap_id: t.Optional[str] = None - if snapshot.is_materialized and ( - model.wap_supported or adapter.wap_supported(target_table_name) + if ( + snapshot.is_materialized + and target_table_exists + and (model.wap_supported or adapter.wap_supported(target_table_name)) ): wap_id = random_id()[0:8] logger.info("Using WAP ID '%s' for snapshot %s", wap_id, snapshot.snapshot_id) @@ -823,6 +825,7 @@ def _evaluate_snapshot( create_render_kwargs=create_render_kwargs, rendered_physical_properties=rendered_physical_properties, deployability_index=deployability_index, + target_table_name=target_table_name, is_first_insert=is_first_insert, batch_index=batch_index, ) @@ -896,6 +899,17 @@ def create_snapshot( if on_complete is not None: on_complete(snapshot) + def wap_publish_snapshot( + self, + snapshot: Snapshot, + wap_id: str, + deployability_index: t.Optional[DeployabilityIndex], + ) -> None: + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) + adapter = self.get_adapter(snapshot.model_gateway) + adapter.wap_publish(table_name, wap_id) + def _render_and_insert_snapshot( self, start: TimeLike, @@ -907,6 +921,7 @@ def _render_and_insert_snapshot( create_render_kwargs: t.Dict[str, t.Any], rendered_physical_properties: t.Dict[str, exp.Expression], deployability_index: DeployabilityIndex, + target_table_name: str, is_first_insert: bool, batch_index: int, ) -> None: @@ -916,7 +931,6 @@ def _render_and_insert_snapshot( logger.info("Inserting data for snapshot %s", snapshot.snapshot_id) model = snapshot.model - table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) adapter = self.get_adapter(model.gateway) evaluation_strategy = _evaluation_strategy(snapshot, adapter) @@ -930,7 +944,7 @@ def _render_and_insert_snapshot( def apply(query_or_df: QueryOrDF, index: int = 0) -> None: if index > 0: evaluation_strategy.append( - table_name=table_name, + table_name=target_table_name, query_or_df=query_or_df, model=snapshot.model, snapshot=snapshot, @@ -948,10 +962,10 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: "Inserting batch (%s, %s) into %s'", time_like_to_str(start), time_like_to_str(end), - table_name, + target_table_name, ) evaluation_strategy.insert( - table_name=table_name, + table_name=target_table_name, query_or_df=query_or_df, is_first_insert=is_first_insert, model=snapshot.model, @@ -1278,17 +1292,6 @@ def _cleanup_snapshot( if on_complete is not None: on_complete(table_name) - def _wap_publish_snapshot( - self, - snapshot: Snapshot, - wap_id: str, - deployability_index: t.Optional[DeployabilityIndex], - ) -> None: - deployability_index = deployability_index or DeployabilityIndex.all_deployable() - table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) - adapter = self.get_adapter(snapshot.model_gateway) - adapter.wap_publish(table_name, wap_id) - def _audit( self, audit: Audit, diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index 2e4f6ae2a0..f1929639a2 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -1080,3 +1080,19 @@ def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture): "CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) USING ICEBERG", "CREATE TABLE IF NOT EXISTS `test_table` USING ICEBERG TBLPROPERTIES ('write.format.default'='orc') AS SELECT CAST(`cola` AS TIMESTAMP) AS `cola`, CAST(`colb` AS STRING) AS `colb`, CAST(`colc` AS STRING) AS `colc` FROM (SELECT CAST(1 AS TIMESTAMP) AS `cola`, CAST(2 AS STRING) AS `colb`, 'foo' AS `colc`) AS `_subquery`", ] + + +def test_get_data_object_wap_branch(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SparkEngineAdapter, patch_get_data_objects=False) + mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + + table = exp.to_table( + "`catalog`.`sqlmesh__test`.`test__test_view__630979748`.`branch_wap_472234d7`", + dialect="spark", + ) + adapter.get_data_object(table) + + adapter._get_data_objects.assert_called_once_with( + d.schema_("sqlmesh__test", "catalog"), + {"test__test_view__630979748"}, + ) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 3b72a14f5f..60908ed7c4 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing as t from typing_extensions import Self -from unittest.mock import call, patch +from unittest.mock import call, patch, Mock import re import logging import pytest @@ -2907,7 +2907,7 @@ def test_standalone_audit(mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock.session.assert_not_called() -def test_audit_wap(adapter_mock, make_snapshot): +def test_audit_wap(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: evaluator = SnapshotEvaluator(adapter_mock) custom_audit = ModelAudit( @@ -4331,3 +4331,253 @@ def test_multiple_engine_virtual_layer(snapshot: Snapshot, adapters, make_snapsh "test_schema__test_env.test_model", cascade=False, ) + + +def test_wap_basic( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot], mocker: MockerFixture +) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = True + + expected_wap_table = "test_schema.test_table.branch_wap_12345678" + adapter_mock.wap_prepare.return_value = expected_wap_table + + wap_id = evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, # Use parameter to control table existence + ) + + assert wap_id is not None + assert len(wap_id) == 8 + + expected_table_name = snapshot.table_name() + adapter_mock.wap_prepare.assert_called_once_with(expected_table_name, wap_id) + adapter_mock.replace_query.assert_called_once_with( + expected_wap_table, + mocker.ANY, + table_format=mocker.ANY, + storage_format=mocker.ANY, + partitioned_by=mocker.ANY, + partition_interval_unit=mocker.ANY, + clustered_by=mocker.ANY, + table_properties=mocker.ANY, + table_description=mocker.ANY, + column_descriptions=mocker.ANY, + target_columns_to_types=mocker.ANY, + source_columns=mocker.ANY, + ) + + +def test_wap_model_wap_supported( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot], mocker: MockerFixture +) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + storage_format="iceberg", # Model supports WAP via iceberg format + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = False + + expected_wap_table = "test_schema.test_table.branch_wap_12345678" + adapter_mock.wap_prepare.return_value = expected_wap_table + + wap_id = evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, # Use parameter to control table existence + ) + assert wap_id is not None + + expected_table_name = snapshot.table_name() + adapter_mock.wap_prepare.assert_called_once_with(expected_table_name, wap_id) + adapter_mock.replace_query.assert_called_once_with( + expected_wap_table, + mocker.ANY, + table_format=mocker.ANY, + storage_format=mocker.ANY, + partitioned_by=mocker.ANY, + partition_interval_unit=mocker.ANY, + clustered_by=mocker.ANY, + table_properties=mocker.ANY, + table_description=mocker.ANY, + column_descriptions=mocker.ANY, + target_columns_to_types=mocker.ANY, + source_columns=mocker.ANY, + ) + + +def test_wap_no_wap_support(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = False + + wap_id = evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, + ) + + assert wap_id is None + adapter_mock.wap_prepare.assert_not_called() + + +def test_wap_non_materialized_snapshot( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot] +) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=ViewKind(), # View kind is not materialized + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = True + + wap_id = evaluator.evaluate( + snapshot, start="2020-01-01", end="2020-01-01", execution_time="2020-01-01", snapshots={} + ) + + assert wap_id is None + adapter_mock.wap_prepare.assert_not_called() + + +def test_wap_publish_snapshot(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + wap_id = "test_wap_id" + deployability_index = DeployabilityIndex.all_deployable() + + evaluator.wap_publish_snapshot(snapshot, wap_id, deployability_index) + + expected_table_name = snapshot.table_name(is_deployable=True) + adapter_mock.wap_publish.assert_called_once_with(expected_table_name, wap_id) + + +def test_wap_during_audit(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + custom_audit = ModelAudit( + name="custom_audit", + query="SELECT * FROM test_schema.test_table WHERE invalid_condition", + ) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + audits=[ + ("not_null", {"columns": exp.to_column("a")}), + ("custom_audit", {}), + ], + audit_definitions={custom_audit.name: custom_audit}, + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + wap_id = "test_wap_id" + expected_wap_table_name = f"test_schema.test_table.branch_wap_{wap_id}" + adapter_mock.wap_table_name.return_value = expected_wap_table_name + adapter_mock.fetchone.return_value = (0,) + + results = evaluator.audit(snapshot, snapshots={}, wap_id=wap_id) + + assert len(results) == 2 + + adapter_mock.wap_table_name.assert_called_once_with(snapshot.table_name(), wap_id) + adapter_mock.wap_publish.assert_called_once_with(snapshot.table_name(), wap_id) + + +def test_wap_prepare_failure(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = True + + adapter_mock.wap_prepare.side_effect = Exception("WAP prepare failed") + + with pytest.raises(Exception, match="WAP prepare failed"): + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, + ) + + +def test_wap_publish_failure(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + """Test error handling when WAP publish fails.""" + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + audits=[("not_null", {"columns": exp.to_column("a")})], + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + wap_id = "test_wap_id" + expected_wap_table_name = f"test_schema.test_table.branch_wap_{wap_id}" + adapter_mock.wap_table_name.return_value = expected_wap_table_name + adapter_mock.fetchone.return_value = (0,) + + # Mock WAP publish to raise an exception + adapter_mock.wap_publish.side_effect = Exception("WAP publish failed") + + # Execute audit with WAP ID and expect it to raise the exception + with pytest.raises(Exception, match="WAP publish failed"): + evaluator.audit(snapshot, snapshots={}, wap_id=wap_id)