From a17b90557b21bd7dda69f1b8324a1827b03cc230 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 8 Sep 2025 01:49:36 +0000 Subject: [PATCH 1/4] Feat: prevent other processes seeing missing intervals during restatement --- sqlmesh/core/console.py | 57 +++++ sqlmesh/core/plan/evaluator.py | 70 ++++-- sqlmesh/core/plan/explainer.py | 91 +++++++- sqlmesh/core/plan/stages.py | 38 +-- tests/core/test_integration.py | 406 ++++++++++++++++++++++++++++++++- tests/core/test_plan_stages.py | 353 +++++++++++++++++++++++++++- 6 files changed, 960 insertions(+), 55 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index af28f75932..11358f5095 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -551,6 +551,23 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: """Display list of models that failed during evaluation to the user.""" + @abc.abstractmethod + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment: EnvironmentSummary, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + """Display a list of models where new versions got deployed to the specified :environment while we were restating data the old versions + + Args: + snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples + environment: which environment got updated while we were restating models + environment_naming_info: how snapshots are named in that :environment (for display name purposes) + default_catalog: the configured default catalog (for display name purposes) + """ + @abc.abstractmethod def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: """Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message.""" @@ -771,6 +788,15 @@ def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: pass + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment: EnvironmentSummary, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + def log_destructive_change( self, snapshot_name: str, @@ -2225,6 +2251,37 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: for node_name, msg in error_messages.items(): self._print(f" [red]{node_name}[/red]\n\n{msg}") + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment: EnvironmentSummary, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + ) -> None: + if snapshots: + tree = Tree( + f"[yellow]The following models had new versions deployed in plan '{environment.plan_id}' while data was being restated:[/yellow]" + ) + + for restated_snapshot, updated_snapshot in snapshots: + display_name = restated_snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + current_branch = tree.add(display_name) + current_branch.add(f"restated version: '{restated_snapshot.version}'") + current_branch.add(f"currently active version: '{updated_snapshot.version}'") + + self._print(tree) + + self.log_warning( + f"\nThe '{environment.name}' environment currently points to [bold]different[/bold] versions of these models, not the versions that just got restated." + ) + self._print( + "[yellow]If this is undesirable, please re-run this restatement plan which will apply it to the most recent versions of these models.[/yellow]\n" + ) + def log_destructive_change( self, snapshot_name: str, diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 79053e018b..cf5012bc4e 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -22,7 +22,7 @@ from sqlmesh.core.console import Console, get_console from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements from sqlmesh.core.macros import RuntimeStage -from sqlmesh.core.snapshot.definition import to_view_mapping +from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo from sqlmesh.core.plan import stages from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.scheduler import Scheduler @@ -287,34 +287,62 @@ def visit_audit_only_run_stage( def visit_restatement_stage( self, stage: stages.RestatementStage, plan: EvaluatablePlan ) -> None: - snapshot_intervals_to_restate = { - (s.id_and_version, i) for s, i in stage.snapshot_intervals.items() - } - - # Restating intervals on prod plans should mean that the intervals are cleared across - # all environments, not just the version currently in prod - # This ensures that work done in dev environments can still be promoted to prod - # by forcing dev environments to re-run intervals that changed in prod + # Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled + # (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments. + # + # This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to + # re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing) + # + # It also means that any new dev environments created while this restatement plan was running also get the + # correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan + # was created, which could have been several hours ago if there was a lot of data to restate. # # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod - snapshot_intervals_to_restate.update( - { - (s.snapshot, s.interval) - for s in identify_restatement_intervals_across_snapshot_versions( - state_reader=self.state_sync, - prod_restatements=plan.restatements, - disable_restatement_models=plan.disabled_restatement_models, - loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, - current_ts=to_timestamp(plan.execution_time or now()), - ).values() - } + + intervals_to_clear = identify_restatement_intervals_across_snapshot_versions( + state_reader=self.state_sync, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + current_ts=to_timestamp(plan.execution_time or now()), ) + if not intervals_to_clear: + # Nothing to do + return + self.state_sync.remove_intervals( - snapshot_intervals=list(snapshot_intervals_to_restate), + snapshot_intervals=[(s.table_info, s.interval) for s in intervals_to_clear.values()], remove_shared_versions=plan.is_prod, ) + # While the restatements were being processed, did any of the snapshots being restated get new versions deployed? + # If they did, they will not reflect the data that just got restated, so we need to notify the user + if deployed_env := self.state_sync.get_environment(plan.environment.name): + promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots} + + deployed_during_restatement: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]] = [] + + for name in plan.restatements: + snapshot = stage.all_snapshots[name] + version = snapshot.table_info.version + if ( + prod_snapshot := promoted_snapshots_by_name.get(name) + ) and prod_snapshot.version != version: + deployed_during_restatement.append( + (snapshot.table_info, prod_snapshot.table_info) + ) + + if deployed_during_restatement: + self.console.log_models_updated_during_restatement( + deployed_during_restatement, + deployed_env.summary, + plan.environment.naming_info, + self.default_catalog, + ) + # note: the plan will automatically fail at the promotion stage with a ConflictingPlanError because the environment was changed by another plan + # so there is no need to explicitly fail the plan here + def visit_environment_record_update_stage( self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan ) -> None: diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index ee829aeac1..19c0a8dda2 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import abc import typing as t import logging +from dataclasses import dataclass from rich.console import Console as RichConsole from rich.tree import Tree @@ -8,6 +11,11 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, TerminalConsole, get_console from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.snapshot.definition import DeployabilityIndex +from sqlmesh.core.plan.common import ( + SnapshotIntervalClearRequest, + identify_restatement_intervals_across_snapshot_versions, +) from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals from sqlmesh.core.plan import stages from sqlmesh.core.plan.evaluator import ( @@ -45,6 +53,15 @@ def evaluate( explainer_console = _get_explainer_console( self.console, plan.environment, self.default_catalog ) + + # add extra metadata that's only needed at this point for better --explain output + plan_stages = [ + ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan) + if isinstance(stage, stages.RestatementStage) + else stage + for stage in plan_stages + ] + explainer_console.explain(plan_stages) @@ -54,6 +71,61 @@ def explain(self, stages: t.List[stages.PlanStage]) -> None: pass +@dataclass +class ExplainableRestatementStage(stages.RestatementStage): + """ + This brings forward some calculations that would usually be done in the evaluator so the user can be given a better indication + of what might happen when they ask for the plan to be explained + """ + + snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest] + """Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name""" + + deployability_index: DeployabilityIndex + """Deployability of those snapshots (which arent necessarily present in the current plan so we cant use the + plan deployability index), used for outputting physical table names""" + + @classmethod + def from_restatement_stage( + cls: t.Type[ExplainableRestatementStage], + stage: stages.RestatementStage, + state_reader: StateReader, + plan: EvaluatablePlan, + ) -> ExplainableRestatementStage: + all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions( + state_reader=state_reader, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + ) + + snapshot_intervals_to_clear = {} + deployability_index = DeployabilityIndex.all_deployable() + + if all_restatement_intervals: + snapshot_intervals_to_clear = { + s_id.name: r for s_id, r in all_restatement_intervals.items() + } + + # creating a deployability index over the "snapshot intervals to clear" + # allows us to print the physical names of the tables affected in the console output + # note that we can't use the DeployabilityIndex on the plan because it only includes + # snapshots for the current environment, not across all environments + deployability_index = DeployabilityIndex.create( + snapshots=state_reader.get_snapshots( + [s.snapshot_id for s in snapshot_intervals_to_clear.values()] + ), + start=plan.start, + start_override_per_model=plan.start_override_per_model, + ) + + return cls( + snapshot_intervals_to_clear=snapshot_intervals_to_clear, + deployability_index=deployability_index, + all_snapshots=stage.all_snapshots, + ) + + MAX_TREE_LENGTH = 10 @@ -146,11 +218,22 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree: tree.add(display_name) return tree - def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree: + def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree: + return self.visit_restatement_stage(stage) + + def visit_restatement_stage( + self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage] + ) -> Tree: tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]") - for snapshot_table_info, interval in stage.snapshot_intervals.items(): - display_name = self._display_name(snapshot_table_info) - tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]") + + if isinstance(stage, ExplainableRestatementStage) and ( + snapshot_intervals := stage.snapshot_intervals_to_clear + ): + for clear_request in snapshot_intervals.values(): + display_name = self._display_name(clear_request.table_info) + interval = clear_request.interval + tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]") + return tree def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree: diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 91c8c6ff14..0d829a6739 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -12,7 +12,6 @@ Snapshot, SnapshotTableInfo, SnapshotId, - Interval, ) @@ -98,14 +97,19 @@ class AuditOnlyRunStage: @dataclass class RestatementStage: - """Restate intervals for given snapshots. + """Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod. + + This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed. + The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have + been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time. + + Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments. Args: - snapshot_intervals: Intervals to restate. - all_snapshots: All snapshots in the plan by name. + all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their + intervals cleared, it's included here as an optimization to prevent having to re-fetch the current plan's snapshots """ - snapshot_intervals: t.Dict[SnapshotTableInfo, Interval] all_snapshots: t.Dict[str, Snapshot] @@ -321,10 +325,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: if audit_only_snapshots: stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values()))) - restatement_stage = self._get_restatement_stage(plan, snapshots_by_name) - if restatement_stage: - stages.append(restatement_stage) - if missing_intervals_before_promote: stages.append( BackfillStage( @@ -349,6 +349,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: ) ) + # note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage) + # needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them. + # in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data. + # we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces + # data for existing intervals and does not produce new ones + restatement_stage = self._get_restatement_stage(plan, snapshots_by_name) + if restatement_stage: + stages.append(restatement_stage) + stages.append( EnvironmentRecordUpdateStage( no_gaps_snapshot_names={s.name for s in before_promote_snapshots} @@ -443,15 +452,12 @@ def _get_after_all_stage( def _get_restatement_stage( self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot] ) -> t.Optional[RestatementStage]: - snapshot_intervals_to_restate = {} - for name, interval in plan.restatements.items(): - restated_snapshot = snapshots_by_name[name] - restated_snapshot.remove_interval(interval) - snapshot_intervals_to_restate[restated_snapshot.table_info] = interval - if not snapshot_intervals_to_restate or plan.is_dev: + if not plan.restatements or plan.is_dev: + # The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod return None + return RestatementStage( - snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name + all_snapshots=snapshots_by_name, ) def _get_physical_layer_update_stage( diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index ef7c59ea7d..40cad93058 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -14,7 +14,7 @@ import pytest from pytest import MonkeyPatch from pathlib import Path -from sqlmesh.core.console import set_console, get_console, TerminalConsole +from sqlmesh.core.console import set_console, get_console, TerminalConsole, CaptureTerminalConsole from sqlmesh.core.config.naming import NameInferenceConfig from sqlmesh.core.model.common import ParsableSql from sqlmesh.utils.concurrency import NodeExecutionFailedError @@ -24,7 +24,9 @@ from sqlglot.expressions import DataType import re from IPython.utils.capture import capture_output - +from concurrent.futures import ThreadPoolExecutor, TimeoutError +import time +import queue from sqlmesh import CustomMaterialization from sqlmesh.cli.project_init import init_example_project @@ -72,7 +74,13 @@ SnapshotTableInfo, ) from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp -from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError +from sqlmesh.utils.errors import ( + NoChangesPlanError, + SQLMeshError, + PlanError, + ConfigError, + ConflictingPlanError, +) from sqlmesh.utils.pydantic import validate_string from tests.conftest import DuckDBMetadata, SushiDataValidator from sqlmesh.utils import CorrelationId @@ -10181,3 +10189,395 @@ def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Pa assert test_result.testsRun == len(test_result.successes) context.close() + + +def test_restatement_plan_interval_external_visibility(tmp_path: Path): + """ + Scenario: + - `prod` environment exists, models A <- B + - `dev` environment created, models A <- B(dev) <- C (dev) + - Restatement plan is triggered against `prod` for model A + - During restatement, a new dev environment `dev_2` is created with a new version of B(dev_2) + + Outcome: + - At no point are the prod_intervals considered "missing" from state for A + - The intervals for B(dev) and C(dev) are cleared + - The intervals for B(dev_2) are also cleared even though the environment didnt exist at the time the plan was started, + because they are based on the data from a partially restated version of A + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + lock_file_path = tmp_path / "test.lock" # python model blocks while this file is present + + evaluation_lock_file_path = ( + tmp_path / "evaluation.lock" + ) # python model creates this file if it's in the wait loop and deletes it once done + + # Note: to make execution block so we can test stuff, we use a Python model that blocks until it no longer detects the presence of a file + (models_dir / "model_a.py").write_text(f""" +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "test.model_a", + is_sql=True, + kind="FULL" +) +def entrypoint(evaluator: MacroEvaluator) -> str: + from pathlib import Path + import time + + if evaluator.runtime_stage == 'evaluating': + while True: + if Path("{str(lock_file_path)}").exists(): + Path("{str(evaluation_lock_file_path)}").touch() + print("lock exists; sleeping") + time.sleep(2) + else: + Path("{str(evaluation_lock_file_path)}").unlink(missing_ok=True) + break + + return "select 'model_a' as m" +""") + + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb from test.model_a as a + """) + + config = Config( + gateways={ + "": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "db.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=[tmp_path], config=config) + + ctx.plan(environment="prod", auto_apply=True) + + assert len(ctx.snapshots) == 2 + assert all(s.intervals for s in ctx.snapshots.values()) + + prod_model_a_snapshot_id = ctx.snapshots['"db"."test"."model_a"'].snapshot_id + prod_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + + # dev models + # new version of B + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb, 'dev' as dev_version from test.model_a as a + """) + + # add C + (models_dir / "model_c.sql").write_text(""" + MODEL ( + name test.model_c, + kind FULL + ); + + select b.*, 'model_c' as mc from test.model_b as b + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + + dev_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + dev_model_c_snapshot_id = ctx.snapshots['"db"."test"."model_c"'].snapshot_id + + assert dev_model_b_snapshot_id != prod_model_b_snapshot_id + + # now, we restate A in prod but touch the lockfile so it hangs during evaluation + # we also have to do it in its own thread due to the hang + lock_file_path.touch() + + def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): + q.put("thread_started") + + # give this thread its own Context object to prevent segfaulting the Python interpreter + restatement_ctx = Context(paths=[tmp_path], config=config) + + # dev2 not present before the restatement plan starts + assert restatement_ctx.state_sync.get_environment("dev2") is None + + q.put("plan_started") + plan = restatement_ctx.plan( + environment="prod", restate_models=['"db"."test"."model_a"'], auto_apply=True + ) + q.put("plan_completed") + + # dev2 was created during the restatement plan + assert restatement_ctx.state_sync.get_environment("dev2") is not None + + return plan + + executor = ThreadPoolExecutor() + q: queue.Queue = queue.Queue() + restatement_plan_future = executor.submit(_run_restatement_plan, tmp_path, config, q) + assert q.get() == "thread_started" + + try: + if e := restatement_plan_future.exception(timeout=1): + # abort early if the plan thread threw an exception + raise e + except TimeoutError: + # that's ok, we dont actually expect the plan to have finished in 1 second + pass + + # while that restatement is running, we can simulate another process and check that it sees no empty intervals + assert q.get() == "plan_started" + + # dont check for potentially missing intervals until the plan is in the evaluation loop + attempts = 0 + while not evaluation_lock_file_path.exists(): + time.sleep(2) + attempts += 1 + if attempts > 10: + raise ValueError("Gave up waiting for evaluation loop") + + ctx.clear_caches() # get rid of the file cache so that data is re-fetched from state + prod_models_from_state = ctx.state_sync.get_snapshots( + snapshot_ids=[prod_model_a_snapshot_id, prod_model_b_snapshot_id] + ) + + # prod intervals should be present still + assert all(m.intervals for m in prod_models_from_state.values()) + + # so should dev intervals since prod restatement is still running + assert all(m.intervals for m in ctx.snapshots.values()) + + # now, lets create a new dev environment "dev2", while the prod restatement plan is still running, + # that changes model_b while still being based on the original version of model_a + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb, 'dev2' as dev_version from test.model_a as a + """) + ctx.load() + ctx.plan(environment="dev2", auto_apply=True) + + dev2_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + assert dev2_model_b_snapshot_id != dev_model_b_snapshot_id + assert dev2_model_b_snapshot_id != prod_model_b_snapshot_id + + # as at this point, everything still has intervals + ctx.clear_caches() + assert all( + s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + prod_model_a_snapshot_id, + prod_model_b_snapshot_id, + dev_model_b_snapshot_id, + dev_model_c_snapshot_id, + dev2_model_b_snapshot_id, + ] + ).values() + ) + + # now, we finally let that restatement plan complete + # first, verify it's still blocked where it should be + assert not restatement_plan_future.done() + + lock_file_path.unlink() # remove lock file, plan should be able to proceed now + + if e := restatement_plan_future.exception(): # blocks until future complete + raise e + + assert restatement_plan_future.result() + assert q.get() == "plan_completed" + + ctx.clear_caches() + + # check that intervals in prod are present + assert all( + s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + prod_model_a_snapshot_id, + prod_model_b_snapshot_id, + ] + ).values() + ) + + # check that intervals in dev have been cleared, including the dev2 env that + # was created after the restatement plan started + assert all( + not s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + dev_model_b_snapshot_id, + dev_model_c_snapshot_id, + dev2_model_b_snapshot_id, + ] + ).values() + ) + + executor.shutdown() + + +def test_restatement_plan_detects_prod_deployment_during_restatement(tmp_path: Path): + """ + Scenario: + - `prod` environment exists, model A + - `dev` environment created, model A(dev) + - Restatement plan is triggered against `prod` for model A + - During restatement, someone else deploys A(dev) to prod, replacing the model that is currently being restated. + + Outcome: + - The deployment plan for dev -> prod should succeed in deploying the new version + - The prod restatement plan should fail with a ConflictingPlanError and warn about the model that got updated while undergoing restatement + """ + orig_console = get_console() + console = CaptureTerminalConsole() + set_console(console) + + models_dir = tmp_path / "models" + models_dir.mkdir() + + lock_file_path = tmp_path / "test.lock" # python model blocks while this file is present + + evaluation_lock_file_path = ( + tmp_path / "evaluation.lock" + ) # python model creates this file if it's in the wait loop and deletes it once done + + # Note: to make execution block so we can test stuff, we use a Python model that blocks until it no longer detects the presence of a file + (models_dir / "model_a.py").write_text(f""" +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "test.model_a", + is_sql=True, + kind="FULL" +) +def entrypoint(evaluator: MacroEvaluator) -> str: + from pathlib import Path + import time + + if evaluator.runtime_stage == 'evaluating': + while True: + if Path("{str(lock_file_path)}").exists(): + Path("{str(evaluation_lock_file_path)}").touch() + print("lock exists; sleeping") + time.sleep(2) + else: + Path("{str(evaluation_lock_file_path)}").unlink(missing_ok=True) + break + + return "select 'model_a' as m" +""") + + config = Config( + gateways={ + "": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "db.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=[tmp_path], config=config) + + # create prod + ctx.plan(environment="prod", auto_apply=True) + original_prod = ctx.state_sync.get_environment("prod") + assert original_prod + + # update model_a for dev + (models_dir / "model_a.py").unlink() + (models_dir / "model_a.sql").write_text(""" + MODEL ( + name test.model_a, + kind FULL + ); + + select 1 as changed + """) + + # create dev + ctx.load() + plan = ctx.plan(environment="dev", auto_apply=True) + assert len(plan.modified_snapshots) == 1 + + # now, trigger a prod restatement plan in a different thread and block it to simulate a long restatement + def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): + q.put("thread_started") + + # give this thread its own Context object to prevent segfaulting the Python interpreter + restatement_ctx = Context(paths=[tmp_path], config=config) + + # ensure dev is present before the restatement plan starts + assert restatement_ctx.state_sync.get_environment("dev") is not None + + q.put("plan_started") + expected_error = None + try: + restatement_ctx.plan( + environment="prod", restate_models=['"db"."test"."model_a"'], auto_apply=True + ) + except ConflictingPlanError as e: + expected_error = e + + q.put("plan_completed") + return expected_error + + executor = ThreadPoolExecutor() + q: queue.Queue = queue.Queue() + lock_file_path.touch() + + restatement_plan_future = executor.submit(_run_restatement_plan, tmp_path, config, q) + restatement_plan_future.add_done_callback(lambda _: executor.shutdown()) + + assert q.get() == "thread_started" + + try: + if e := restatement_plan_future.exception(timeout=1): + # abort early if the plan thread threw an exception + raise e + except TimeoutError: + # that's ok, we dont actually expect the plan to have finished in 1 second + pass + + assert q.get() == "plan_started" + + # ok, now the prod restatement plan is running, let's deploy dev to prod + ctx.plan(environment="prod", auto_apply=True) + + new_prod = ctx.state_sync.get_environment("prod") + assert new_prod + assert new_prod.plan_id != original_prod.plan_id + assert new_prod.previous_plan_id == original_prod.plan_id + + # new prod is deployed but restatement plan is still running + assert not restatement_plan_future.done() + + # allow restatement plan to complete + lock_file_path.unlink() + + plan_error = restatement_plan_future.result() + assert isinstance(plan_error, ConflictingPlanError) + + output = " ".join(re.split("\s+", console.captured_output, flags=re.UNICODE)) + assert ( + f"The following models had new versions deployed in plan '{new_prod.plan_id}' while data was being restated: └── test.model_a" + in output + ) + assert "please re-run this restatement plan" in output + + set_console(orig_console) diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index 744c7d18bf..930b1bb21f 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -6,6 +6,7 @@ from sqlmesh.core.config import EnvironmentSuffixTarget from sqlmesh.core.config.common import VirtualEnvironmentMode from sqlmesh.core.model import SqlModel, ModelKindName +from sqlmesh.core.plan.common import SnapshotIntervalClearRequest from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.plan.stages import ( build_plan_stages, @@ -23,11 +24,13 @@ FinalizeEnvironmentStage, UnpauseStage, ) +from sqlmesh.core.plan.explainer import ExplainableRestatementStage from sqlmesh.core.snapshot.definition import ( SnapshotChangeCategory, DeployabilityIndex, Snapshot, SnapshotId, + SnapshotIdLike, ) from sqlmesh.core.state_sync import StateReader from sqlmesh.core.environment import Environment, EnvironmentStatements @@ -499,15 +502,25 @@ def test_build_plan_stages_basic_no_backfill( assert isinstance(stages[7], FinalizeEnvironmentStage) -def test_build_plan_stages_restatement( +def test_build_plan_stages_restatement_prod_only( snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture ) -> None: + """ + Scenario: + - Prod restatement triggered in a project with no dev environments + + Expected Outcome: + - Plan still contains a RestatementStage in case a dev environment was + created during restatement + """ + # Mock state reader to return existing snapshots and environment state_reader = mocker.Mock(spec=StateReader) state_reader.get_snapshots.return_value = { snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, } + existing_environment = Environment( name="prod", snapshots=[snapshot_a.table_info, snapshot_b.table_info], @@ -518,7 +531,9 @@ def test_build_plan_stages_restatement( promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], finalized_ts=to_timestamp("2023-01-02"), ) + state_reader.get_environment.return_value = existing_environment + state_reader.get_environments_summary.return_value = [existing_environment.summary] environment = Environment( name="prod", @@ -577,17 +592,167 @@ def test_build_plan_stages_restatement( snapshot_b.snapshot_id, } - # Verify RestatementStage - restatement_stage = stages[1] + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 2 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_backfill_interval = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + for intervals in backfill_stage.snapshot_to_intervals.values(): + assert intervals == expected_backfill_interval + + # Verify RestatementStage exists but is empty + restatement_stage = stages[2] assert isinstance(restatement_stage, RestatementStage) - assert len(restatement_stage.snapshot_intervals) == 2 - expected_interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) - for snapshot_info, interval in restatement_stage.snapshot_intervals.items(): - assert interval == expected_interval - assert snapshot_info.name in ('"a"', '"b"') + restatement_stage = ExplainableRestatementStage.from_restatement_stage( + restatement_stage, state_reader, plan + ) + assert not restatement_stage.snapshot_intervals_to_clear + assert ( + restatement_stage.deployability_index == DeployabilityIndex.all_deployable() + ) # default index + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + +def test_build_plan_stages_restatement_prod_identifies_dev_intervals( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + make_snapshot: t.Callable[..., Snapshot], + mocker: MockerFixture, +) -> None: + """ + Scenario: + - Prod restatement triggered in a project with a dev environment + - The dev environment contains a different physical version of the affected model + + Expected Outcome: + - Plan contains a RestatementStage that highlights the affected dev version + """ + # Dev version of snapshot_a, same name but different version + snapshot_a_dev = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, changed, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot_a_dev.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot_a_dev.snapshot_id != snapshot_a.snapshot_id + assert snapshot_a_dev.table_info != snapshot_a.table_info + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + snapshots_in_state = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_a_dev.snapshot_id: snapshot_a_dev, + } + + def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): + return { + k: v + for k, v in snapshots_in_state.items() + if k in {s.snapshot_id for s in snapshot_ids} + } + + state_reader.get_snapshots.side_effect = _get_snapshots + state_reader.get_snapshot_ids_by_names.return_value = set() + + existing_prod_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + # dev has new version of snapshot_a but same version of snapshot_b + existing_dev_environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.side_effect = ( + lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + ) + state_reader.get_environments_summary.return_value = [ + existing_prod_environment.summary, + existing_dev_environment.summary, + ] + + environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + '"b"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } # Verify BackfillStage - backfill_stage = stages[2] + backfill_stage = stages[1] assert isinstance(backfill_stage, BackfillStage) assert len(backfill_stage.snapshot_to_intervals) == 2 assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() @@ -595,6 +760,24 @@ def test_build_plan_stages_restatement( for intervals in backfill_stage.snapshot_to_intervals.values(): assert intervals == expected_backfill_interval + # Verify RestatementStage + restatement_stage = stages[2] + assert isinstance(restatement_stage, RestatementStage) + restatement_stage = ExplainableRestatementStage.from_restatement_stage( + restatement_stage, state_reader, plan + ) + + # note: we only clear the intervals from state for "a" in dev, we leave prod alone + assert restatement_stage.snapshot_intervals_to_clear + assert len(restatement_stage.snapshot_intervals_to_clear) == 1 + assert restatement_stage.deployability_index is not None + snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0] + assert isinstance(clear_request, SnapshotIntervalClearRequest) + assert snapshot_name == '"a"' + assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id + assert clear_request.table_info == snapshot_a_dev.table_info + assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + # Verify EnvironmentRecordUpdateStage assert isinstance(stages[3], EnvironmentRecordUpdateStage) @@ -602,6 +785,155 @@ def test_build_plan_stages_restatement( assert isinstance(stages[4], FinalizeEnvironmentStage) +def test_build_plan_stages_restatement_dev_does_not_clear_intervals( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + make_snapshot: t.Callable[..., Snapshot], + mocker: MockerFixture, +) -> None: + """ + Scenario: + - Restatement triggered against the dev environment + + Expected Outcome: + - BackfillStage only touches models in that dev environment + - Plan does not contain a RestatementStage because making changes in dev doesnt mean we need + to clear intervals from other environments + """ + # Dev version of snapshot_a, same name but different version + snapshot_a_dev = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, changed, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot_a_dev.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot_a_dev.snapshot_id != snapshot_a.snapshot_id + assert snapshot_a_dev.table_info != snapshot_a.table_info + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + snapshots_in_state = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_a_dev.snapshot_id: snapshot_a_dev, + } + state_reader.get_snapshots.side_effect = lambda snapshot_info_like: { + k: v + for k, v in snapshots_in_state.items() + if k in [sil.snapshot_id for sil in snapshot_info_like] + } + + # prod has snapshot_a, snapshot_b + existing_prod_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_prod_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + # dev has new version of snapshot_a + existing_dev_environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_dev_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.side_effect = ( + lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + ) + state_reader.get_environments_summary.return_value = [ + existing_prod_environment.summary, + existing_dev_environment.summary, + ] + + environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_dev_plan", + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify no RestatementStage + assert not any(s for s in stages if isinstance(s, RestatementStage)) + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 1 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a_dev.snapshot_id, + } + + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 1 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + backfill_snapshot, backfill_intervals = list(backfill_stage.snapshot_to_intervals.items())[0] + assert backfill_snapshot.snapshot_id == snapshot_a_dev.snapshot_id + assert backfill_intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + + # Verify VirtualLayerUpdateStage (all non-prod plans get this regardless) + assert isinstance(stages[3], VirtualLayerUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + def test_build_plan_stages_forward_only( snapshot_a: Snapshot, snapshot_b: Snapshot, make_snapshot, mocker: MockerFixture ) -> None: @@ -1686,6 +2018,7 @@ def test_adjust_intervals_restatement_removal( state_reader.refresh_snapshot_intervals = mocker.Mock() state_reader.get_snapshots.return_value = {} state_reader.get_environment.return_value = None + state_reader.get_environments_summary.return_value = [] environment = Environment( snapshots=[snapshot_a.table_info, snapshot_b.table_info], @@ -1738,8 +2071,6 @@ def test_adjust_intervals_restatement_removal( restatement_stages = [stage for stage in stages if isinstance(stage, RestatementStage)] assert len(restatement_stages) == 1 - restatement_stage = restatement_stages[0] - assert len(restatement_stage.snapshot_intervals) == 2 backfill_stages = [stage for stage in stages if isinstance(stage, BackfillStage)] assert len(backfill_stages) == 1 From 00b3d893f7a310f251268f810a681037a1e1c65b Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 8 Sep 2025 02:24:13 +0000 Subject: [PATCH 2/4] fix tests --- tests/core/test_plan_stages.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index 930b1bb21f..0e48dd8a2e 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -520,6 +520,10 @@ def test_build_plan_stages_restatement_prod_only( snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, } + state_reader.get_snapshots_by_names.return_value = { + snapshot_a.id_and_version, + snapshot_b.id_and_version, + } existing_environment = Environment( name="prod", @@ -661,7 +665,7 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): } state_reader.get_snapshots.side_effect = _get_snapshots - state_reader.get_snapshot_ids_by_names.return_value = set() + state_reader.get_snapshots_by_names.return_value = set() existing_prod_environment = Environment( name="prod", From c1da66a649ea5cd2d910a1da74667897cdb2bb3d Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Wed, 10 Sep 2025 02:03:50 +0000 Subject: [PATCH 3/4] PR feedback --- sqlmesh/core/console.py | 13 ++----- sqlmesh/core/plan/evaluator.py | 54 +++++++++++++++++++---------- sqlmesh/core/plan/explainer.py | 52 +++++++++------------------ sqlmesh/core/snapshot/definition.py | 14 +++++++- tests/core/test_integration.py | 16 +++++++-- tests/core/test_plan_stages.py | 6 +--- 6 files changed, 81 insertions(+), 74 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 11358f5095..3b6cb1ce07 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -555,7 +555,6 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: def log_models_updated_during_restatement( self, snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], - environment: EnvironmentSummary, environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: @@ -791,7 +790,6 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: def log_models_updated_during_restatement( self, snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], - environment: EnvironmentSummary, environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: @@ -2254,13 +2252,12 @@ def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: def log_models_updated_during_restatement( self, snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], - environment: EnvironmentSummary, environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str] = None, ) -> None: if snapshots: tree = Tree( - f"[yellow]The following models had new versions deployed in plan '{environment.plan_id}' while data was being restated:[/yellow]" + f"[yellow]The following models had new versions deployed while data was being restated:[/yellow]" ) for restated_snapshot, updated_snapshot in snapshots: @@ -2274,13 +2271,7 @@ def log_models_updated_during_restatement( current_branch.add(f"currently active version: '{updated_snapshot.version}'") self._print(tree) - - self.log_warning( - f"\nThe '{environment.name}' environment currently points to [bold]different[/bold] versions of these models, not the versions that just got restated." - ) - self._print( - "[yellow]If this is undesirable, please re-run this restatement plan which will apply it to the most recent versions of these models.[/yellow]\n" - ) + self._print("") # newline spacer def log_destructive_change( self, diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index cf5012bc4e..03ecb770bf 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -40,7 +40,7 @@ from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions from sqlmesh.utils import CorrelationId from sqlmesh.utils.concurrency import NodeExecutionFailedError -from sqlmesh.utils.errors import PlanError, SQLMeshError +from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError from sqlmesh.utils.date import now, to_timestamp logger = logging.getLogger(__name__) @@ -311,37 +311,53 @@ def visit_restatement_stage( # Nothing to do return - self.state_sync.remove_intervals( - snapshot_intervals=[(s.table_info, s.interval) for s in intervals_to_clear.values()], - remove_shared_versions=plan.is_prod, - ) - # While the restatements were being processed, did any of the snapshots being restated get new versions deployed? # If they did, they will not reflect the data that just got restated, so we need to notify the user + deployed_during_restatement: t.Dict[ + str, t.Tuple[SnapshotTableInfo, SnapshotTableInfo] + ] = {} # tuple of (restated_snapshot, current_prod_snapshot) + if deployed_env := self.state_sync.get_environment(plan.environment.name): promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots} - deployed_during_restatement: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]] = [] - for name in plan.restatements: snapshot = stage.all_snapshots[name] version = snapshot.table_info.version if ( prod_snapshot := promoted_snapshots_by_name.get(name) ) and prod_snapshot.version != version: - deployed_during_restatement.append( - (snapshot.table_info, prod_snapshot.table_info) + deployed_during_restatement[name] = ( + snapshot.table_info, + prod_snapshot.table_info, ) - if deployed_during_restatement: - self.console.log_models_updated_during_restatement( - deployed_during_restatement, - deployed_env.summary, - plan.environment.naming_info, - self.default_catalog, - ) - # note: the plan will automatically fail at the promotion stage with a ConflictingPlanError because the environment was changed by another plan - # so there is no need to explicitly fail the plan here + # we need to *not* clear the intervals on the snapshots where new versions were deployed while the restatement was running in order to prevent + # subsequent plans from having unexpected intervals to backfill. + # we instead list the affected models and abort the plan with an error so the user can decide what to do + # (either re-attempt the restatement plan or leave things as they are) + filtered_intervals_to_clear = [ + (s.snapshot, s.interval) + for s in intervals_to_clear.values() + if s.snapshot.name not in deployed_during_restatement + ] + + if filtered_intervals_to_clear: + # We still clear intervals in other envs for models that were successfully restated without having new versions promoted during restatement + self.state_sync.remove_intervals( + snapshot_intervals=filtered_intervals_to_clear, + remove_shared_versions=plan.is_prod, + ) + + if deployed_env and deployed_during_restatement: + self.console.log_models_updated_during_restatement( + list(deployed_during_restatement.values()), + plan.environment.naming_info, + self.default_catalog, + ) + raise ConflictingPlanError( + f"Another plan ({deployed_env.summary.plan_id}) deployed new versions of {len(deployed_during_restatement)} models in the target environment '{plan.environment.name}' while they were being restated by this plan.\n" + "Please re-apply your plan if these new versions should be restated." + ) def visit_environment_record_update_stage( self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index 19c0a8dda2..07fb507345 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -11,7 +11,7 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, TerminalConsole, get_console from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.snapshot.definition import DeployabilityIndex +from sqlmesh.core.snapshot.definition import model_display_name from sqlmesh.core.plan.common import ( SnapshotIntervalClearRequest, identify_restatement_intervals_across_snapshot_versions, @@ -22,9 +22,7 @@ PlanEvaluator, ) from sqlmesh.core.state_sync import StateReader -from sqlmesh.core.snapshot.definition import ( - SnapshotInfoMixin, -) +from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotNameVersionLike from sqlmesh.utils import Verbosity, rich as srich, to_snake_case from sqlmesh.utils.date import to_ts from sqlmesh.utils.errors import SQLMeshError @@ -81,10 +79,6 @@ class ExplainableRestatementStage(stages.RestatementStage): snapshot_intervals_to_clear: t.Dict[str, SnapshotIntervalClearRequest] """Which snapshots from other environments would have intervals cleared as part of restatement, keyed by name""" - deployability_index: DeployabilityIndex - """Deployability of those snapshots (which arent necessarily present in the current plan so we cant use the - plan deployability index), used for outputting physical table names""" - @classmethod def from_restatement_stage( cls: t.Type[ExplainableRestatementStage], @@ -99,29 +93,10 @@ def from_restatement_stage( loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, ) - snapshot_intervals_to_clear = {} - deployability_index = DeployabilityIndex.all_deployable() - - if all_restatement_intervals: - snapshot_intervals_to_clear = { - s_id.name: r for s_id, r in all_restatement_intervals.items() - } - - # creating a deployability index over the "snapshot intervals to clear" - # allows us to print the physical names of the tables affected in the console output - # note that we can't use the DeployabilityIndex on the plan because it only includes - # snapshots for the current environment, not across all environments - deployability_index = DeployabilityIndex.create( - snapshots=state_reader.get_snapshots( - [s.snapshot_id for s in snapshot_intervals_to_clear.values()] - ), - start=plan.start, - start_override_per_model=plan.start_override_per_model, - ) - return cls( - snapshot_intervals_to_clear=snapshot_intervals_to_clear, - deployability_index=deployability_index, + snapshot_intervals_to_clear={ + s.snapshot.name: s for s in all_restatement_intervals.values() + }, all_snapshots=stage.all_snapshots, ) @@ -230,7 +205,7 @@ def visit_restatement_stage( snapshot_intervals := stage.snapshot_intervals_to_clear ): for clear_request in snapshot_intervals.values(): - display_name = self._display_name(clear_request.table_info) + display_name = self._display_name(clear_request.snapshot) interval = clear_request.interval tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]") @@ -348,15 +323,22 @@ def visit_finalize_environment_stage( def _display_name( self, - snapshot: SnapshotInfoMixin, + snapshot: t.Union[SnapshotInfoMixin, SnapshotNameVersionLike], environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, ) -> str: - return snapshot.display_name( - environment_naming_info or self.environment_naming_info, - self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + naming_kwargs: t.Any = dict( + environment_naming_info=environment_naming_info or self.environment_naming_info, + default_catalog=self.default_catalog + if self.verbosity < Verbosity.VERY_VERBOSE + else None, dialect=self.dialect, ) + if isinstance(snapshot, SnapshotInfoMixin): + return snapshot.display_name(**naming_kwargs) + + return model_display_name(node_name=snapshot.name, **naming_kwargs) + def _limit_tree(self, tree: Tree) -> Tree: tree_length = len(tree.children) if tree_length <= MAX_TREE_LENGTH: diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index c17e94be10..8812ca0977 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1788,7 +1788,19 @@ def display_name( """ if snapshot_info_like.is_audit: return snapshot_info_like.name - view_name = exp.to_table(snapshot_info_like.name) + + return model_display_name( + snapshot_info_like.name, environment_naming_info, default_catalog, dialect + ) + + +def model_display_name( + node_name: str, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + dialect: DialectType = None, +) -> str: + view_name = exp.to_table(node_name) catalog = ( None diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 40cad93058..0fad472cd5 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -10440,8 +10440,9 @@ def test_restatement_plan_detects_prod_deployment_during_restatement(tmp_path: P - During restatement, someone else deploys A(dev) to prod, replacing the model that is currently being restated. Outcome: - - The deployment plan for dev -> prod should succeed in deploying the new version + - The deployment plan for dev -> prod should succeed in deploying the new version of A - The prod restatement plan should fail with a ConflictingPlanError and warn about the model that got updated while undergoing restatement + - The new version of A should have no intervals cleared. The user needs to rerun the restatement if the intervals should still be cleared """ orig_console = get_console() console = CaptureTerminalConsole() @@ -10514,6 +10515,7 @@ def entrypoint(evaluator: MacroEvaluator) -> str: ctx.load() plan = ctx.plan(environment="dev", auto_apply=True) assert len(plan.modified_snapshots) == 1 + new_model_a_snapshot_id = list(plan.modified_snapshots)[0] # now, trigger a prod restatement plan in a different thread and block it to simulate a long restatement def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): @@ -10572,12 +10574,20 @@ def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): plan_error = restatement_plan_future.result() assert isinstance(plan_error, ConflictingPlanError) + assert "please re-apply your plan" in repr(plan_error).lower() output = " ".join(re.split("\s+", console.captured_output, flags=re.UNICODE)) assert ( - f"The following models had new versions deployed in plan '{new_prod.plan_id}' while data was being restated: └── test.model_a" + f"The following models had new versions deployed while data was being restated: └── test.model_a" in output ) - assert "please re-run this restatement plan" in output + + # check that no intervals have been cleared from the model_a currently in prod + model_a = ctx.state_sync.get_snapshots(snapshot_ids=[new_model_a_snapshot_id])[ + new_model_a_snapshot_id + ] + assert isinstance(model_a.node, SqlModel) + assert model_a.node.render_query_or_raise().sql() == 'SELECT 1 AS "changed"' + assert len(model_a.intervals) set_console(orig_console) diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index 0e48dd8a2e..4ada7d458d 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -612,9 +612,6 @@ def test_build_plan_stages_restatement_prod_only( restatement_stage, state_reader, plan ) assert not restatement_stage.snapshot_intervals_to_clear - assert ( - restatement_stage.deployability_index == DeployabilityIndex.all_deployable() - ) # default index # Verify EnvironmentRecordUpdateStage assert isinstance(stages[3], EnvironmentRecordUpdateStage) @@ -774,12 +771,11 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): # note: we only clear the intervals from state for "a" in dev, we leave prod alone assert restatement_stage.snapshot_intervals_to_clear assert len(restatement_stage.snapshot_intervals_to_clear) == 1 - assert restatement_stage.deployability_index is not None snapshot_name, clear_request = list(restatement_stage.snapshot_intervals_to_clear.items())[0] assert isinstance(clear_request, SnapshotIntervalClearRequest) assert snapshot_name == '"a"' assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id - assert clear_request.table_info == snapshot_a_dev.table_info + assert clear_request.snapshot == snapshot_a_dev.id_and_version assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) # Verify EnvironmentRecordUpdateStage From acad2ededb6b0a3965e0e90c75ff142856b29b66 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Thu, 11 Sep 2025 21:57:24 +0000 Subject: [PATCH 4/4] Add display_name to SnapshotIdAndVersion --- sqlmesh/core/plan/explainer.py | 12 +++--------- sqlmesh/core/snapshot/definition.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py index 07fb507345..b722d00d58 100644 --- a/sqlmesh/core/plan/explainer.py +++ b/sqlmesh/core/plan/explainer.py @@ -11,7 +11,6 @@ from sqlmesh.core import constants as c from sqlmesh.core.console import Console, TerminalConsole, get_console from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.snapshot.definition import model_display_name from sqlmesh.core.plan.common import ( SnapshotIntervalClearRequest, identify_restatement_intervals_across_snapshot_versions, @@ -22,7 +21,7 @@ PlanEvaluator, ) from sqlmesh.core.state_sync import StateReader -from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotNameVersionLike +from sqlmesh.core.snapshot.definition import SnapshotInfoMixin, SnapshotIdAndVersion from sqlmesh.utils import Verbosity, rich as srich, to_snake_case from sqlmesh.utils.date import to_ts from sqlmesh.utils.errors import SQLMeshError @@ -323,10 +322,10 @@ def visit_finalize_environment_stage( def _display_name( self, - snapshot: t.Union[SnapshotInfoMixin, SnapshotNameVersionLike], + snapshot: t.Union[SnapshotInfoMixin, SnapshotIdAndVersion], environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, ) -> str: - naming_kwargs: t.Any = dict( + return snapshot.display_name( environment_naming_info=environment_naming_info or self.environment_naming_info, default_catalog=self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE @@ -334,11 +333,6 @@ def _display_name( dialect=self.dialect, ) - if isinstance(snapshot, SnapshotInfoMixin): - return snapshot.display_name(**naming_kwargs) - - return model_display_name(node_name=snapshot.name, **naming_kwargs) - def _limit_tree(self, tree: Tree) -> Tree: tree_length = len(tree.children) if tree_length <= MAX_TREE_LENGTH: diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 8812ca0977..9522366721 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -638,6 +638,16 @@ def dev_version(self) -> str: def model_kind_name(self) -> t.Optional[ModelKindName]: return self.kind_name_ + def display_name( + self, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + dialect: DialectType = None, + ) -> str: + return model_display_name( + self.name, environment_naming_info, default_catalog, dialect=dialect + ) + class Snapshot(PydanticModel, SnapshotInfoMixin): """A snapshot represents a node at a certain point in time.