Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,22 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
"""Display list of models that failed during evaluation to the user."""

@abc.abstractmethod
def log_models_updated_during_restatement(
self,
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str],
) -> None:
"""Display a list of models where new versions got deployed to the specified :environment while we were restating data the old versions

Args:
snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples
environment: which environment got updated while we were restating models
environment_naming_info: how snapshots are named in that :environment (for display name purposes)
default_catalog: the configured default catalog (for display name purposes)
"""

@abc.abstractmethod
def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
"""Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message."""
Expand Down Expand Up @@ -771,6 +787,14 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
pass

def log_models_updated_during_restatement(
self,
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str],
) -> None:
pass

def log_destructive_change(
self,
snapshot_name: str,
Expand Down Expand Up @@ -2225,6 +2249,30 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None:
for node_name, msg in error_messages.items():
self._print(f" [red]{node_name}[/red]\n\n{msg}")

def log_models_updated_during_restatement(
self,
snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]],
environment_naming_info: EnvironmentNamingInfo,
default_catalog: t.Optional[str] = None,
) -> None:
if snapshots:
tree = Tree(
f"[yellow]The following models had new versions deployed while data was being restated:[/yellow]"
)

for restated_snapshot, updated_snapshot in snapshots:
display_name = restated_snapshot.display_name(
environment_naming_info,
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
dialect=self.dialect,
)
current_branch = tree.add(display_name)
current_branch.add(f"restated version: '{restated_snapshot.version}'")
current_branch.add(f"currently active version: '{updated_snapshot.version}'")

self._print(tree)
self._print("") # newline spacer

def log_destructive_change(
self,
snapshot_name: str,
Expand Down
94 changes: 69 additions & 25 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sqlmesh.core.console import Console, get_console
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
from sqlmesh.core.macros import RuntimeStage
from sqlmesh.core.snapshot.definition import to_view_mapping
from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo
from sqlmesh.core.plan import stages
from sqlmesh.core.plan.definition import EvaluatablePlan
from sqlmesh.core.scheduler import Scheduler
Expand All @@ -40,7 +40,7 @@
from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
from sqlmesh.utils import CorrelationId
from sqlmesh.utils.concurrency import NodeExecutionFailedError
from sqlmesh.utils.errors import PlanError, SQLMeshError
from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError
from sqlmesh.utils.date import now, to_timestamp

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -287,34 +287,78 @@ def visit_audit_only_run_stage(
def visit_restatement_stage(
self, stage: stages.RestatementStage, plan: EvaluatablePlan
) -> None:
snapshot_intervals_to_restate = {
(s.id_and_version, i) for s, i in stage.snapshot_intervals.items()
}

# Restating intervals on prod plans should mean that the intervals are cleared across
# all environments, not just the version currently in prod
# This ensures that work done in dev environments can still be promoted to prod
# by forcing dev environments to re-run intervals that changed in prod
# Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled
# (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments.
#
# This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to
# re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing)
#
# It also means that any new dev environments created while this restatement plan was running also get the
# correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan
# was created, which could have been several hours ago if there was a lot of data to restate.
#
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
snapshot_intervals_to_restate.update(
{
(s.snapshot, s.interval)
for s in identify_restatement_intervals_across_snapshot_versions(
state_reader=self.state_sync,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
current_ts=to_timestamp(plan.execution_time or now()),
).values()
}
)

self.state_sync.remove_intervals(
snapshot_intervals=list(snapshot_intervals_to_restate),
remove_shared_versions=plan.is_prod,
intervals_to_clear = identify_restatement_intervals_across_snapshot_versions(
state_reader=self.state_sync,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
current_ts=to_timestamp(plan.execution_time or now()),
)

if not intervals_to_clear:
# Nothing to do
return

# While the restatements were being processed, did any of the snapshots being restated get new versions deployed?
# If they did, they will not reflect the data that just got restated, so we need to notify the user
deployed_during_restatement: t.Dict[
str, t.Tuple[SnapshotTableInfo, SnapshotTableInfo]
] = {} # tuple of (restated_snapshot, current_prod_snapshot)

if deployed_env := self.state_sync.get_environment(plan.environment.name):
promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots}

for name in plan.restatements:
snapshot = stage.all_snapshots[name]
version = snapshot.table_info.version
if (
prod_snapshot := promoted_snapshots_by_name.get(name)
) and prod_snapshot.version != version:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it sufficient to just check the version? I'm sure the version will match if the snapshots match

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved internally, this is checking the version, := just got misread as =

deployed_during_restatement[name] = (
snapshot.table_info,
prod_snapshot.table_info,
)

# we need to *not* clear the intervals on the snapshots where new versions were deployed while the restatement was running in order to prevent
# subsequent plans from having unexpected intervals to backfill.
# we instead list the affected models and abort the plan with an error so the user can decide what to do
# (either re-attempt the restatement plan or leave things as they are)
filtered_intervals_to_clear = [
(s.snapshot, s.interval)
for s in intervals_to_clear.values()
if s.snapshot.name not in deployed_during_restatement
]

if filtered_intervals_to_clear:
# We still clear intervals in other envs for models that were successfully restated without having new versions promoted during restatement
self.state_sync.remove_intervals(
snapshot_intervals=filtered_intervals_to_clear,
remove_shared_versions=plan.is_prod,
)

if deployed_env and deployed_during_restatement:
self.console.log_models_updated_during_restatement(
list(deployed_during_restatement.values()),
plan.environment.naming_info,
self.default_catalog,
)
raise ConflictingPlanError(
f"Another plan ({deployed_env.summary.plan_id}) deployed new versions of {len(deployed_during_restatement)} models in the target environment '{plan.environment.name}' while they were being restated by this plan.\n"
"Please re-apply your plan if these new versions should be restated."
)

def visit_environment_record_update_stage(
self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan
) -> None:
Expand Down
79 changes: 69 additions & 10 deletions sqlmesh/core/plan/explainer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from __future__ import annotations

import abc
import typing as t
import logging
from dataclasses import dataclass

from rich.console import Console as RichConsole
from rich.tree import Tree
from sqlglot.dialects.dialect import DialectType
from sqlmesh.core import constants as c
from sqlmesh.core.console import Console, TerminalConsole, get_console
from sqlmesh.core.environment import EnvironmentNamingInfo
from sqlmesh.core.plan.common import (
SnapshotIntervalClearRequest,
identify_restatement_intervals_across_snapshot_versions,
)
from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals
from sqlmesh.core.plan import stages
from sqlmesh.core.plan.evaluator import (
PlanEvaluator,
)
from sqlmesh.core.state_sync import StateReader
from sqlmesh.core.snapshot.definition import (
SnapshotInfoMixin,
)
from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
from sqlmesh.utils.date import to_ts
from sqlmesh.utils.errors import SQLMeshError
Expand Down Expand Up @@ -45,6 +50,15 @@ def evaluate(
explainer_console = _get_explainer_console(
self.console, plan.environment, self.default_catalog
)

# add extra metadata that's only needed at this point for better --explain output
plan_stages = [
ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan)
if isinstance(stage, stages.RestatementStage)
else stage
for stage in plan_stages
]

explainer_console.explain(plan_stages)


Expand All @@ -54,6 +68,38 @@ def explain(self, stages: t.List[stages.PlanStage]) -> None:
pass


@dataclass
class ExplainableRestatementStage(stages.RestatementStage):
"""
This brings forward some calculations that would usually be done in the evaluator so the user can be given a better indication
of what might happen when they ask for the plan to be explained
"""

snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest]
"""Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name"""

@classmethod
def from_restatement_stage(
cls: t.Type[ExplainableRestatementStage],
stage: stages.RestatementStage,
state_reader: StateReader,
plan: EvaluatablePlan,
) -> ExplainableRestatementStage:
all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions(
state_reader=state_reader,
prod_restatements=plan.restatements,
disable_restatement_models=plan.disabled_restatement_models,
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
)

return cls(
snapshot_intervals_to_clear={
s.snapshot.name: s for s in all_restatement_intervals.values()
},
all_snapshots=stage.all_snapshots,
)


MAX_TREE_LENGTH = 10


Expand Down Expand Up @@ -146,11 +192,22 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree:
tree.add(display_name)
return tree

def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree:
def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree:
return self.visit_restatement_stage(stage)

def visit_restatement_stage(
self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage]
) -> Tree:
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
for snapshot_table_info, interval in stage.snapshot_intervals.items():
display_name = self._display_name(snapshot_table_info)
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")

if isinstance(stage, ExplainableRestatementStage) and (
snapshot_intervals := stage.snapshot_intervals_to_clear
):
for clear_request in snapshot_intervals.values():
display_name = self._display_name(clear_request.snapshot)
interval = clear_request.interval
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")

return tree

def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree:
Expand Down Expand Up @@ -265,12 +322,14 @@ def visit_finalize_environment_stage(

def _display_name(
self,
snapshot: SnapshotInfoMixin,
snapshot: t.Union[SnapshotInfoMixin, SnapshotIdAndVersion],
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
) -> str:
return snapshot.display_name(
environment_naming_info or self.environment_naming_info,
self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
environment_naming_info=environment_naming_info or self.environment_naming_info,
default_catalog=self.default_catalog
if self.verbosity < Verbosity.VERY_VERBOSE
else None,
dialect=self.dialect,
)

Expand Down
38 changes: 22 additions & 16 deletions sqlmesh/core/plan/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Snapshot,
SnapshotTableInfo,
SnapshotId,
Interval,
)


Expand Down Expand Up @@ -98,14 +97,19 @@ class AuditOnlyRunStage:

@dataclass
class RestatementStage:
"""Restate intervals for given snapshots.
"""Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod.

This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed.
The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have
been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time.

Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments.

Args:
snapshot_intervals: Intervals to restate.
all_snapshots: All snapshots in the plan by name.
all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their
intervals cleared, it's included here as an optimization to prevent having to re-fetch the current plan's snapshots
"""

snapshot_intervals: t.Dict[SnapshotTableInfo, Interval]
all_snapshots: t.Dict[str, Snapshot]


Expand Down Expand Up @@ -321,10 +325,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
if audit_only_snapshots:
stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values())))

restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
if restatement_stage:
stages.append(restatement_stage)

if missing_intervals_before_promote:
stages.append(
BackfillStage(
Expand All @@ -349,6 +349,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
)
)

# note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage)
# needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them.
# in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data.
# we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces
# data for existing intervals and does not produce new ones
restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
if restatement_stage:
stages.append(restatement_stage)

stages.append(
EnvironmentRecordUpdateStage(
no_gaps_snapshot_names={s.name for s in before_promote_snapshots}
Expand Down Expand Up @@ -443,15 +452,12 @@ def _get_after_all_stage(
def _get_restatement_stage(
self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]
) -> t.Optional[RestatementStage]:
snapshot_intervals_to_restate = {}
for name, interval in plan.restatements.items():
restated_snapshot = snapshots_by_name[name]
restated_snapshot.remove_interval(interval)
snapshot_intervals_to_restate[restated_snapshot.table_info] = interval
if not snapshot_intervals_to_restate or plan.is_dev:
if not plan.restatements or plan.is_dev:
# The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod
return None

return RestatementStage(
snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name
all_snapshots=snapshots_by_name,
)

def _get_physical_layer_update_stage(
Expand Down
Loading