diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md index 858bf9f19d..a184f7d99f 100644 --- a/docs/concepts/macros/macro_variables.md +++ b/docs/concepts/macros/macro_variables.md @@ -132,7 +132,8 @@ SQLMesh provides additional predefined variables used to modify model behavior b * 'loading' - The project is being loaded into SQLMesh's runtime context. * 'creating' - The model tables are being created. * 'evaluating' - The model query logic is being evaluated. - * 'promoting' - The model is being promoted in the target environment (virtual layer update). + * 'promoting' - The model is being promoted in the target environment (view created during virtual layer update). + * 'demoting' - The model is being demoted in the target environment (view dropped during virtual layer update). * 'auditing' - The audit is being run. * 'testing' - The model query logic is being evaluated in the context of a unit test. * @gateway - A string value containing the name of the current [gateway](../../guides/connections.md). diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index 2c5ef3b2e8..ec5b2567f4 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -67,6 +67,7 @@ class RuntimeStage(Enum): CREATING = "creating" EVALUATING = "evaluating" PROMOTING = "promoting" + DEMOTING = "demoting" AUDITING = "auditing" TESTING = "testing" BEFORE_ALL = "before_all" diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 545a5e5494..9488b9bc91 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -340,9 +340,11 @@ def visit_virtual_layer_update_stage( ) if stage.demoted_environment_naming_info: self._demote_snapshots( - stage.demoted_snapshots, + [stage.all_snapshots[s.snapshot_id] for s in stage.demoted_snapshots], stage.demoted_environment_naming_info, + deployability_index=stage.deployability_index, on_complete=lambda s: self.console.update_promotion_progress(s, False), + snapshots=stage.all_snapshots, ) completed = True @@ -382,12 +384,23 @@ def _promote_snapshots( def _demote_snapshots( self, - target_snapshots: t.Iterable[SnapshotTableInfo], + target_snapshots: t.Iterable[Snapshot], environment_naming_info: EnvironmentNamingInfo, + snapshots: t.Dict[SnapshotId, Snapshot], + deployability_index: t.Optional[DeployabilityIndex] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: self.snapshot_evaluator.demote( - target_snapshots, environment_naming_info, on_complete=on_complete + target_snapshots, + environment_naming_info, + table_mapping=to_view_mapping( + snapshots.values(), + environment_naming_info, + default_catalog=self.default_catalog, + dialect=self.snapshot_evaluator.adapter.dialect, + ), + deployability_index=deployability_index, + on_complete=on_complete, ) def _restatement_intervals_across_all_environments( diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 9913a87bd0..194177b0cf 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -361,11 +361,14 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: # Otherwise, unpause right after updatig the environment record. stages.append(UnpauseStage(promoted_snapshots=promoted_snapshots)) + full_demoted_snapshots = self.state_reader.get_snapshots( + s.snapshot_id for s in demoted_snapshots if s.snapshot_id not in snapshots + ) virtual_layer_update_stage = self._get_virtual_layer_update_stage( promoted_snapshots, demoted_snapshots, demoted_environment_naming_info, - snapshots, + snapshots | full_demoted_snapshots, deployability_index, ) if virtual_layer_update_stage: diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 641f216699..993860b527 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -276,8 +276,10 @@ def promote( def demote( self, - target_snapshots: t.Iterable[SnapshotInfoLike], + target_snapshots: t.Iterable[Snapshot], environment_naming_info: EnvironmentNamingInfo, + table_mapping: t.Optional[t.Dict[str, str]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: """Demotes the given collection of snapshots in the target environment by removing its view. @@ -290,7 +292,13 @@ def demote( with self.concurrent_context(): concurrent_apply_to_snapshots( target_snapshots, - lambda s: self._demote_snapshot(s, environment_naming_info, on_complete), + lambda s: self._demote_snapshot( + s, + environment_naming_info, + deployability_index=deployability_index, + on_complete=on_complete, + table_mapping=table_mapping, + ), self.ddl_concurrent_tasks, ) @@ -970,25 +978,32 @@ def _promote_snapshot( snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, table_mapping: t.Optional[t.Dict[str, str]] = None, ) -> None: - if snapshot.is_model: - adapter = ( - self.get_adapter(snapshot.model_gateway) - if environment_naming_info.gateway_managed - else self.adapter - ) - table_name = snapshot.table_name(deployability_index.is_representative(snapshot)) - view_name = snapshot.qualified_view_name.for_environment( - environment_naming_info, dialect=adapter.dialect - ) - render_kwargs: t.Dict[str, t.Any] = dict( - start=start, - end=end, - execution_time=execution_time, - engine_adapter=adapter, - deployability_index=deployability_index, - table_mapping=table_mapping, - runtime_stage=RuntimeStage.PROMOTING, - ) + if not snapshot.is_model: + return + + adapter = ( + self.get_adapter(snapshot.model_gateway) + if environment_naming_info.gateway_managed + else self.adapter + ) + table_name = snapshot.table_name(deployability_index.is_representative(snapshot)) + view_name = snapshot.qualified_view_name.for_environment( + environment_naming_info, dialect=adapter.dialect + ) + render_kwargs: t.Dict[str, t.Any] = dict( + start=start, + end=end, + execution_time=execution_time, + engine_adapter=adapter, + deployability_index=deployability_index, + table_mapping=table_mapping, + runtime_stage=RuntimeStage.PROMOTING, + ) + + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_kwargs)), + ): _evaluation_strategy(snapshot, adapter).promote( table_name=table_name, view_name=view_name, @@ -1007,10 +1022,15 @@ def _promote_snapshot( def _demote_snapshot( self, - snapshot: SnapshotInfoLike, + snapshot: Snapshot, environment_naming_info: EnvironmentNamingInfo, + deployability_index: t.Optional[DeployabilityIndex], on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + table_mapping: t.Optional[t.Dict[str, str]] = None, ) -> None: + if not snapshot.is_model: + return + adapter = ( self.get_adapter(snapshot.model_gateway) if environment_naming_info.gateway_managed @@ -1019,7 +1039,18 @@ def _demote_snapshot( view_name = snapshot.qualified_view_name.for_environment( environment_naming_info, dialect=adapter.dialect ) - _evaluation_strategy(snapshot, adapter).demote(view_name) + with ( + adapter.transaction(), + adapter.session( + snapshot.model.render_session_properties( + engine_adapter=adapter, + deployability_index=deployability_index, + table_mapping=table_mapping, + runtime_stage=RuntimeStage.DEMOTING, + ) + ), + ): + _evaluation_strategy(snapshot, adapter).demote(view_name) if on_complete is not None: on_complete(snapshot) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 2888511ba1..93cef90daf 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -294,6 +294,8 @@ def test_promote(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + adapter_mock.transaction.assert_called() + adapter_mock.session.assert_called() adapter_mock.create_schema.assert_called_once_with(to_schema("test_schema__test_env")) adapter_mock.create_view.assert_called_once_with( "test_schema__test_env.test_model", @@ -320,6 +322,8 @@ def test_demote(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator.demote([snapshot], EnvironmentNamingInfo(name="test_env")) + adapter_mock.transaction.assert_called() + adapter_mock.session.assert_called() adapter_mock.drop_view.assert_called_once_with( "test_schema__test_env.test_model", cascade=False,