Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,7 +2022,34 @@ def _prompt_categorize(
plan = plan_builder.build()

if plan.restatements:
self._print("\n[bold]Restating models\n")
# A plan can have restatements for the following reasons:
# - The user specifically called `sqlmesh plan` with --restate-model.
# This creates a "restatement plan" which disallows all other changes and simply force-backfills
# the selected models and their downstream dependencies using the versions of the models stored in state.
# - There are no specific restatements (so changes are allowed) AND dev previews need to be computed.
# The "restatements" feature is currently reused for dev previews.
if plan.selected_models_to_restate:
# There were legitimate restatements, no dev previews
tree = Tree(
"[bold]Models selected for restatement:[/bold]\n"
"This causes backfill of the model itself as well as affected downstream models"
)
model_fqn_to_snapshot = {s.name: s for s in plan.snapshots.values()}
for model_fqn in plan.selected_models_to_restate:
snapshot = model_fqn_to_snapshot[model_fqn]
display_name = snapshot.display_name(
plan.environment_naming_info,
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
dialect=self.dialect,
)
tree.add(
display_name
) # note: we deliberately dont show any intervals here; they get shown in the backfill section
self._print(tree)
else:
# We are computing dev previews, do not confuse the user by printing out something to do
# with restatements. Dev previews are already highlighted in the backfill step
pass
else:
self.show_environment_difference_summary(
plan.context_diff,
Expand Down
1 change: 1 addition & 0 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def build(self) -> Plan:
directly_modified=directly_modified,
indirectly_modified=indirectly_modified,
deployability_index=deployability_index,
selected_models_to_restate=self._restate_models,
restatements=restatements,
start_override_per_model=self._start_override_per_model,
end_override_per_model=end_override_per_model,
Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/core/plan/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,16 @@ class Plan(PydanticModel, frozen=True):
indirectly_modified: t.Dict[SnapshotId, t.Set[SnapshotId]]

deployability_index: DeployabilityIndex
selected_models_to_restate: t.Optional[t.Set[str]] = None
"""Models that have been explicitly selected for restatement by a user"""
restatements: t.Dict[SnapshotId, Interval]
"""
All models being restated, which are typically the explicitly selected ones + their downstream dependencies.
Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty
while :restatements is still populated with dev previews
"""

start_override_per_model: t.Optional[t.Dict[str, datetime]]
end_override_per_model: t.Optional[t.Dict[str, datetime]]

Expand Down
45 changes: 34 additions & 11 deletions sqlmesh/core/plan/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing as t
import logging
from dataclasses import dataclass
from collections import defaultdict

from rich.console import Console as RichConsole
from rich.tree import Tree
Expand All @@ -21,7 +22,11 @@
PlanEvaluator,
)
from sqlmesh.core.state_sync import StateReader
from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion
from sqlmesh.core.snapshot.definition import (
SnapshotInfoMixin,
SnapshotIdAndVersion,
model_display_name,
)
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
from sqlmesh.utils.date import to_ts
from sqlmesh.utils.errors import SQLMeshError
Expand Down Expand Up @@ -75,8 +80,8 @@ class ExplainableRestatementStage(stages.RestatementStage):
of what might happen when they ask for the plan to be explained
"""

snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""
snapshot_intervals_to_clear: t.Dict[str, t.List[SnapshotIntervalClearRequest]]
"""Which snapshots from other environments would have intervals cleared as part of restatement, grouped by name."""

@classmethod
def from_restatement_stage(
Expand All @@ -92,10 +97,13 @@ def from_restatement_stage(
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
)

# Group the interval clear requests by snapshot name to make them easier to write to the console
snapshot_intervals_to_clear = defaultdict(list)
for clear_request in all_restatement_intervals.values():
snapshot_intervals_to_clear[clear_request.snapshot.name].append(clear_request)

return cls(
snapshot_intervals_to_clear={
s.snapshot.name: s for s in all_restatement_intervals.values()
},
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
all_snapshots=stage.all_snapshots,
)

Expand Down Expand Up @@ -198,15 +206,30 @@ def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage
def visit_restatement_stage(
self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
) -> Tree:
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
tree = Tree(
"[bold]Invalidate data intervals in state for development environments to prevent old data from being promoted[/bold]\n"
"This only affects state and will not clear physical data from the tables until the next plan for each environment"
)

if isinstance(stage, ExplainableRestatementStage) and (
snapshot_intervals := stage.snapshot_intervals_to_clear
):
for clear_request in snapshot_intervals.values():
display_name = self._display_name(clear_request.snapshot)
interval = clear_request.interval
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
for name, clear_requests in snapshot_intervals.items():
display_name = model_display_name(
name, self.environment_naming_info, self.default_catalog, self.dialect
)
interval_start = min(cr.interval[0] for cr in clear_requests)
interval_end = max(cr.interval[1] for cr in clear_requests)

if not interval_start or not interval_end:
continue

node = tree.add(f"{display_name} [{to_ts(interval_start)} - {to_ts(interval_end)}]")

all_environment_names = sorted(
set(env_name for cr in clear_requests for env_name in cr.environment_names)
)
node.add("in environments: " + ", ".join(all_environment_names))

return tree

Expand Down
2 changes: 1 addition & 1 deletion tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_plan_restate_model(runner, tmp_path):
)
assert result.exit_code == 0
assert_duckdb_test(result)
assert "Restating models" in result.output
assert "Models selected for restatement" in result.output
assert "sqlmesh_example.full_model [full refresh" in result.output
assert_model_batches_executed(result)
assert "Virtual layer updated" not in result.output
Expand Down
6 changes: 4 additions & 2 deletions tests/core/test_plan_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,9 +771,11 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
# note: we only clear the intervals from state for "a" in dev, we leave prod alone
assert restatement_stage.snapshot_intervals_to_clear
assert len(restatement_stage.snapshot_intervals_to_clear) == 1
snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0]
assert isinstance(clear_request, SnapshotIntervalClearRequest)
snapshot_name, clear_requests = list(restatement_stage.snapshot_intervals_to_clear.items())[0]
assert snapshot_name == '"a"'
assert len(clear_requests) == 1
clear_request = clear_requests[0]
assert isinstance(clear_request, SnapshotIntervalClearRequest)
assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id
assert clear_request.snapshot == snapshot_a_dev.id_and_version
assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))
Expand Down