Skip to content

Commit 12bd71e

Browse files
authored
Fix: establish transaction and session before snapshot promotion/demotion (#4899)
1 parent b3bd132 commit 12bd71e

File tree

6 files changed

+81
-28
lines changed

6 files changed

+81
-28
lines changed

docs/concepts/macros/macro_variables.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ SQLMesh provides additional predefined variables used to modify model behavior b
132132
* 'loading' - The project is being loaded into SQLMesh's runtime context.
133133
* 'creating' - The model tables are being created.
134134
* 'evaluating' - The model query logic is being evaluated.
135-
* 'promoting' - The model is being promoted in the target environment (virtual layer update).
135+
* 'promoting' - The model is being promoted in the target environment (view created during virtual layer update).
136+
* 'demoting' - The model is being demoted in the target environment (view dropped during virtual layer update).
136137
* 'auditing' - The audit is being run.
137138
* 'testing' - The model query logic is being evaluated in the context of a unit test.
138139
* @gateway - A string value containing the name of the current [gateway](../../guides/connections.md).

sqlmesh/core/macros.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class RuntimeStage(Enum):
6767
CREATING = "creating"
6868
EVALUATING = "evaluating"
6969
PROMOTING = "promoting"
70+
DEMOTING = "demoting"
7071
AUDITING = "auditing"
7172
TESTING = "testing"
7273
BEFORE_ALL = "before_all"

sqlmesh/core/plan/evaluator.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,11 @@ def visit_virtual_layer_update_stage(
340340
)
341341
if stage.demoted_environment_naming_info:
342342
self._demote_snapshots(
343-
stage.demoted_snapshots,
343+
[stage.all_snapshots[s.snapshot_id] for s in stage.demoted_snapshots],
344344
stage.demoted_environment_naming_info,
345+
deployability_index=stage.deployability_index,
345346
on_complete=lambda s: self.console.update_promotion_progress(s, False),
347+
snapshots=stage.all_snapshots,
346348
)
347349

348350
completed = True
@@ -382,12 +384,23 @@ def _promote_snapshots(
382384

383385
def _demote_snapshots(
384386
self,
385-
target_snapshots: t.Iterable[SnapshotTableInfo],
387+
target_snapshots: t.Iterable[Snapshot],
386388
environment_naming_info: EnvironmentNamingInfo,
389+
snapshots: t.Dict[SnapshotId, Snapshot],
390+
deployability_index: t.Optional[DeployabilityIndex] = None,
387391
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None,
388392
) -> None:
389393
self.snapshot_evaluator.demote(
390-
target_snapshots, environment_naming_info, on_complete=on_complete
394+
target_snapshots,
395+
environment_naming_info,
396+
table_mapping=to_view_mapping(
397+
snapshots.values(),
398+
environment_naming_info,
399+
default_catalog=self.default_catalog,
400+
dialect=self.snapshot_evaluator.adapter.dialect,
401+
),
402+
deployability_index=deployability_index,
403+
on_complete=on_complete,
391404
)
392405

393406
def _restatement_intervals_across_all_environments(

sqlmesh/core/plan/stages.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,14 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
361361
# Otherwise, unpause right after updatig the environment record.
362362
stages.append(UnpauseStage(promoted_snapshots=promoted_snapshots))
363363

364+
full_demoted_snapshots = self.state_reader.get_snapshots(
365+
s.snapshot_id for s in demoted_snapshots if s.snapshot_id not in snapshots
366+
)
364367
virtual_layer_update_stage = self._get_virtual_layer_update_stage(
365368
promoted_snapshots,
366369
demoted_snapshots,
367370
demoted_environment_naming_info,
368-
snapshots,
371+
snapshots | full_demoted_snapshots,
369372
deployability_index,
370373
)
371374
if virtual_layer_update_stage:

sqlmesh/core/snapshot/evaluator.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,10 @@ def promote(
276276

277277
def demote(
278278
self,
279-
target_snapshots: t.Iterable[SnapshotInfoLike],
279+
target_snapshots: t.Iterable[Snapshot],
280280
environment_naming_info: EnvironmentNamingInfo,
281+
table_mapping: t.Optional[t.Dict[str, str]] = None,
282+
deployability_index: t.Optional[DeployabilityIndex] = None,
281283
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None,
282284
) -> None:
283285
"""Demotes the given collection of snapshots in the target environment by removing its view.
@@ -290,7 +292,13 @@ def demote(
290292
with self.concurrent_context():
291293
concurrent_apply_to_snapshots(
292294
target_snapshots,
293-
lambda s: self._demote_snapshot(s, environment_naming_info, on_complete),
295+
lambda s: self._demote_snapshot(
296+
s,
297+
environment_naming_info,
298+
deployability_index=deployability_index,
299+
on_complete=on_complete,
300+
table_mapping=table_mapping,
301+
),
294302
self.ddl_concurrent_tasks,
295303
)
296304

@@ -970,25 +978,32 @@ def _promote_snapshot(
970978
snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None,
971979
table_mapping: t.Optional[t.Dict[str, str]] = None,
972980
) -> None:
973-
if snapshot.is_model:
974-
adapter = (
975-
self.get_adapter(snapshot.model_gateway)
976-
if environment_naming_info.gateway_managed
977-
else self.adapter
978-
)
979-
table_name = snapshot.table_name(deployability_index.is_representative(snapshot))
980-
view_name = snapshot.qualified_view_name.for_environment(
981-
environment_naming_info, dialect=adapter.dialect
982-
)
983-
render_kwargs: t.Dict[str, t.Any] = dict(
984-
start=start,
985-
end=end,
986-
execution_time=execution_time,
987-
engine_adapter=adapter,
988-
deployability_index=deployability_index,
989-
table_mapping=table_mapping,
990-
runtime_stage=RuntimeStage.PROMOTING,
991-
)
981+
if not snapshot.is_model:
982+
return
983+
984+
adapter = (
985+
self.get_adapter(snapshot.model_gateway)
986+
if environment_naming_info.gateway_managed
987+
else self.adapter
988+
)
989+
table_name = snapshot.table_name(deployability_index.is_representative(snapshot))
990+
view_name = snapshot.qualified_view_name.for_environment(
991+
environment_naming_info, dialect=adapter.dialect
992+
)
993+
render_kwargs: t.Dict[str, t.Any] = dict(
994+
start=start,
995+
end=end,
996+
execution_time=execution_time,
997+
engine_adapter=adapter,
998+
deployability_index=deployability_index,
999+
table_mapping=table_mapping,
1000+
runtime_stage=RuntimeStage.PROMOTING,
1001+
)
1002+
1003+
with (
1004+
adapter.transaction(),
1005+
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
1006+
):
9921007
_evaluation_strategy(snapshot, adapter).promote(
9931008
table_name=table_name,
9941009
view_name=view_name,
@@ -1007,10 +1022,15 @@ def _promote_snapshot(
10071022

10081023
def _demote_snapshot(
10091024
self,
1010-
snapshot: SnapshotInfoLike,
1025+
snapshot: Snapshot,
10111026
environment_naming_info: EnvironmentNamingInfo,
1027+
deployability_index: t.Optional[DeployabilityIndex],
10121028
on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]],
1029+
table_mapping: t.Optional[t.Dict[str, str]] = None,
10131030
) -> None:
1031+
if not snapshot.is_model:
1032+
return
1033+
10141034
adapter = (
10151035
self.get_adapter(snapshot.model_gateway)
10161036
if environment_naming_info.gateway_managed
@@ -1019,7 +1039,18 @@ def _demote_snapshot(
10191039
view_name = snapshot.qualified_view_name.for_environment(
10201040
environment_naming_info, dialect=adapter.dialect
10211041
)
1022-
_evaluation_strategy(snapshot, adapter).demote(view_name)
1042+
with (
1043+
adapter.transaction(),
1044+
adapter.session(
1045+
snapshot.model.render_session_properties(
1046+
engine_adapter=adapter,
1047+
deployability_index=deployability_index,
1048+
table_mapping=table_mapping,
1049+
runtime_stage=RuntimeStage.DEMOTING,
1050+
)
1051+
),
1052+
):
1053+
_evaluation_strategy(snapshot, adapter).demote(view_name)
10231054

10241055
if on_complete is not None:
10251056
on_complete(snapshot)

tests/core/test_snapshot_evaluator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ def test_promote(mocker: MockerFixture, adapter_mock, make_snapshot):
294294

295295
evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env"))
296296

297+
adapter_mock.transaction.assert_called()
298+
adapter_mock.session.assert_called()
297299
adapter_mock.create_schema.assert_called_once_with(to_schema("test_schema__test_env"))
298300
adapter_mock.create_view.assert_called_once_with(
299301
"test_schema__test_env.test_model",
@@ -320,6 +322,8 @@ def test_demote(mocker: MockerFixture, adapter_mock, make_snapshot):
320322

321323
evaluator.demote([snapshot], EnvironmentNamingInfo(name="test_env"))
322324

325+
adapter_mock.transaction.assert_called()
326+
adapter_mock.session.assert_called()
323327
adapter_mock.drop_view.assert_called_once_with(
324328
"test_schema__test_env.test_model",
325329
cascade=False,

0 commit comments

Comments
 (0)