diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 3b6cb1ce07..d9567ae484 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -2022,7 +2022,34 @@ def _prompt_categorize( plan = plan_builder.build() if plan.restatements: - self._print("\n[bold]Restating models\n") + # A plan can have restatements for the following reasons: + # - The user specifically called `sqlmesh plan` with --restate-model. + # This creates a "restatement plan" which disallows all other changes and simply force-backfills + # the selected models and their downstream dependencies using the versions of the models stored in state. + # - There are no specific restatements (so changes are allowed) AND dev previews need to be computed. + # The "restatements" feature is currently reused for dev previews. + if plan.selected_models_to_restate: + # There were legitimate restatements, no dev previews + tree = Tree( + "[bold]Models selected for restatement:[/bold]\n" + "This causes backfill of the model itself as well as affected downstream models" + ) + model_fqn_to_snapshot = {s.name: s for s in plan.snapshots.values()} + for model_fqn in plan.selected_models_to_restate: + snapshot = model_fqn_to_snapshot[model_fqn] + display_name = snapshot.display_name( + plan.environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + tree.add( + display_name + ) # note: we deliberately dont show any intervals here; they get shown in the backfill section + self._print(tree) + else: + # We are computing dev previews, do not confuse the user by printing out something to do + # with restatements. Dev previews are already highlighted in the backfill step + pass else: self.show_environment_difference_summary( plan.context_diff, diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index a84b3b60dc..79af460d1d 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -338,6 +338,7 @@ def build(self) -> Plan: directly_modified=directly_modified, indirectly_modified=indirectly_modified, deployability_index=deployability_index, + selected_models_to_restate=self._restate_models, restatements=restatements, start_override_per_model=self._start_override_per_model, end_override_per_model=end_override_per_model, diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index aaf6ec5dc0..5ed3e4b188 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -58,7 +58,16 @@ class Plan(PydanticModel, frozen=True): indirectly_modified: t.Dict[SnapshotId, t.Set[SnapshotId]] deployability_index: DeployabilityIndex + selected_models_to_restate: t.Optional[t.Set[str]] = None + """Models that have been explicitly selected for restatement by a user""" restatements: t.Dict[SnapshotId, Interval] + """ + All models being restated, which are typically the explicitly selected ones + their downstream dependencies. + + Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty + while :restatements is still populated with dev previews + """ + start_override_per_model: t.Optional[t.Dict[str, datetime]] end_override_per_model: t.Optional[t.Dict[str, datetime]] diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index b722d00d58..f0a1e44aff 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -4,6 +4,7 @@ import typing as t import logging from dataclasses import dataclass +from collections import defaultdict from rich.console import Console as RichConsole from rich.tree import Tree @@ -21,7 +22,11 @@ PlanEvaluator, ) from sqlmesh.core.state_sync import StateReader -from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion +from sqlmesh.core.snapshot.definition import ( + SnapshotInfoMixin, + SnapshotIdAndVersion, + model_display_name, +) from sqlmesh.utils import Verbosity, rich as srich, to_snake_case from sqlmesh.utils.date import to_ts from sqlmesh.utils.errors import SQLMeshError @@ -75,8 +80,8 @@ class ExplainableRestatementStage(stages.RestatementStage): of what might happen when they ask for the plan to be explained """ - snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest] - """Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name""" + snapshot_intervals_to_clear: t.Dict[str, t.List[SnapshotIntervalClearRequest]] + """Which snapshots from other environments would have intervals cleared as part of restatement, grouped by name.""" @classmethod def from_restatement_stage( @@ -92,10 +97,13 @@ def from_restatement_stage( loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, ) + # Group the interval clear requests by snapshot name to make them easier to write to the console + snapshot_intervals_to_clear = defaultdict(list) + for clear_request in all_restatement_intervals.values(): + snapshot_intervals_to_clear[clear_request.snapshot.name].append(clear_request) + return cls( - snapshot_intervals_to_clear={ - s.snapshot.name: s for s in all_restatement_intervals.values() - }, + snapshot_intervals_to_clear=snapshot_intervals_to_clear, all_snapshots=stage.all_snapshots, ) @@ -198,15 +206,30 @@ def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage def visit_restatement_stage( self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage] ) -> Tree: - tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]") + tree = Tree( + "[bold]Invalidate data intervals in state for development environments to prevent old data from being promoted[/bold]\n" + "This only affects state and will not clear physical data from the tables until the next plan for each environment" + ) if isinstance(stage, ExplainableRestatementStage) and ( snapshot_intervals := stage.snapshot_intervals_to_clear ): - for clear_request in snapshot_intervals.values(): - display_name = self._display_name(clear_request.snapshot) - interval = clear_request.interval - tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]") + for name, clear_requests in snapshot_intervals.items(): + display_name = model_display_name( + name, self.environment_naming_info, self.default_catalog, self.dialect + ) + interval_start = min(cr.interval[0] for cr in clear_requests) + interval_end = max(cr.interval[1] for cr in clear_requests) + + if not interval_start or not interval_end: + continue + + node = tree.add(f"{display_name} [{to_ts(interval_start)} - {to_ts(interval_end)}]") + + all_environment_names = sorted( + set(env_name for cr in clear_requests for env_name in cr.environment_names) + ) + node.add("in environments: " + ", ".join(all_environment_names)) return tree diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index e460387bbc..1cb41bddec 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -247,7 +247,7 @@ def test_plan_restate_model(runner, tmp_path): ) assert result.exit_code == 0 assert_duckdb_test(result) - assert "Restating models" in result.output + assert "Models selected for restatement" in result.output assert "sqlmesh_example.full_model [full refresh" in result.output assert_model_batches_executed(result) assert "Virtual layer updated" not in result.output diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index 4ada7d458d..444ce1bb9b 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -771,9 +771,11 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): # note: we only clear the intervals from state for "a" in dev, we leave prod alone assert restatement_stage.snapshot_intervals_to_clear assert len(restatement_stage.snapshot_intervals_to_clear) == 1 - snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0] - assert isinstance(clear_request, SnapshotIntervalClearRequest) + snapshot_name, clear_requests = list(restatement_stage.snapshot_intervals_to_clear.items())[0] assert snapshot_name == '"a"' + assert len(clear_requests) == 1 + clear_request = clear_requests[0] + assert isinstance(clear_request, SnapshotIntervalClearRequest) assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id assert clear_request.snapshot == snapshot_a_dev.id_and_version assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))