From c8c457512ee79afe7db33825469d289fe9377bcc Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Mon, 8 Sep 2025 01:46:07 +0000 Subject: [PATCH 1/3] Fix: When restating prod, clear intervals across all related snapshots, not just promoted ones --- sqlmesh/core/plan/common.py | 190 ++++++++++++++++++++++++++++++++- sqlmesh/core/plan/evaluator.py | 114 +++----------------- tests/core/test_integration.py | 174 ++++++++++++++++++++++++++++++ 3 files changed, 376 insertions(+), 102 deletions(-) diff --git a/sqlmesh/core/plan/common.py b/sqlmesh/core/plan/common.py index 929837eb7e..2a59092f67 100644 --- a/sqlmesh/core/plan/common.py +++ b/sqlmesh/core/plan/common.py @@ -1,6 +1,15 @@ from __future__ import annotations +import typing as t +import logging +from dataclasses import dataclass, field -from sqlmesh.core.snapshot import Snapshot +from sqlmesh.core.state_sync import StateReader +from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo, SnapshotNameVersion +from sqlmesh.core.snapshot.definition import Interval +from sqlmesh.utils.dag import DAG +from sqlmesh.utils.date import now_timestamp + +logger = logging.getLogger(__name__) def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool: @@ -27,3 +36,182 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool: # If the partitioning hasn't changed, then we don't need to rebuild return False return True + + +@dataclass +class SnapshotIntervalClearRequest: + # affected snapshot + table_info: SnapshotTableInfo + + # which interval to clear + interval: Interval + + # which environments this snapshot is currently promoted + # note that this can be empty if the snapshot exists because its ttl has not expired + # but it is not part of any particular environment + environment_names: t.Set[str] = field(default_factory=set) + + @property + def snapshot_id(self) -> SnapshotId: + return self.table_info.snapshot_id + + @property + def sorted_environment_names(self) -> t.List[str]: + return list(sorted(self.environment_names)) + + +def identify_restatement_intervals_across_snapshot_versions( + state_reader: StateReader, + prod_restatements: t.Dict[str, Interval], + disable_restatement_models: t.Set[str], + loaded_snapshots: t.Dict[SnapshotId, Snapshot], + current_ts: t.Optional[int] = None, +) -> t.Dict[SnapshotId, SnapshotIntervalClearRequest]: + """ + Given a map of snapshot names + intervals to restate in prod: + - Look up matching snapshots (match based on name - regardless of version, to get all versions) + - For each match, also match downstream snapshots in each dev environment while filtering out models that have restatement disabled + - Return a list of all snapshots that are affected + the interval that needs to be cleared for each + + The goal here is to produce a list of intervals to invalidate across all dev snapshots so that a subsequent plan or + cadence run in those environments causes the intervals to be repopulated. + """ + if not prod_restatements: + return {} + + # Although :loaded_snapshots is sourced from RestatementStage.all_snapshots, since the only time we ever need + # to clear intervals across all environments is for prod, the :loaded_snapshots here are always from prod + prod_name_versions: t.Set[SnapshotNameVersion] = { + s.name_version for s in loaded_snapshots.values() + } + + snapshot_intervals_to_clear: t.Dict[SnapshotId, SnapshotIntervalClearRequest] = {} + + for env_summary in state_reader.get_environments_summary(): + # Fetch the full environment object one at a time to avoid loading all environments into memory at once + env = state_reader.get_environment(env_summary.name) + if not env: + logger.warning("Environment %s not found", env_summary.name) + continue + + snapshots_by_name = {s.name: s.table_info for s in env.snapshots} + + # We dont just restate matching snapshots, we also have to restate anything downstream of them + # so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev + env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots}) + + for restate_snapshot_name, interval in prod_restatements.items(): + if restate_snapshot_name not in snapshots_by_name: + # snapshot is not promoted in this environment + continue + + affected_snapshot_names = [ + x + for x in ([restate_snapshot_name] + env_dag.downstream(restate_snapshot_name)) + if x not in disable_restatement_models + ] + + for affected_snapshot_name in affected_snapshot_names: + affected_snapshot = snapshots_by_name[affected_snapshot_name] + + # Don't clear intervals for a dev snapshot if it shares the same physical version with prod. + # Otherwise, prod will be affected by what should be a dev operation + if affected_snapshot.name_version in prod_name_versions: + continue + + clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id) + if not clear_request: + clear_request = SnapshotIntervalClearRequest( + table_info=affected_snapshot, interval=interval + ) + snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request + + clear_request.environment_names |= set([env.name]) + + # snapshot_intervals_to_clear now contains the entire hierarchy of affected snapshots based + # on building the DAG for each environment and including downstream snapshots + # but, what if there are affected snapshots that arent part of any environment? + unique_snapshot_names = set(snapshot_id.name for snapshot_id in snapshot_intervals_to_clear) + + current_ts = current_ts or now_timestamp() + all_matching_non_prod_snapshots = { + s.snapshot_id: s + for s in state_reader.get_snapshots_by_names( + snapshot_names=unique_snapshot_names, current_ts=current_ts, exclude_expired=True + ) + # Don't clear intervals for a snapshot if it shares the same physical version with prod. + # Otherwise, prod will be affected by what should be a dev operation + if s.name_version not in prod_name_versions + } + + # identify the ones that we havent picked up yet, which are the ones that dont exist in any environment + if remaining_snapshot_ids := set(all_matching_non_prod_snapshots).difference( + snapshot_intervals_to_clear + ): + # these snapshot id's exist in isolation and may be related to a downstream dependency of the :prod_restatements, + # rather than directly related, so we can't simply look up the interval to clear based on :prod_restatements. + # To figure out the interval that should be cleared, we can match to the existing list based on name + # and conservatively take the widest interval that shows up + snapshot_name_to_widest_interval: t.Dict[str, Interval] = {} + for s_id, clear_request in snapshot_intervals_to_clear.items(): + current_start, current_end = snapshot_name_to_widest_interval.get( + s_id.name, clear_request.interval + ) + next_start, next_end = clear_request.interval + + next_start = min(current_start, next_start) + next_end = max(current_end, next_end) + + snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end) + + # we need to fetch full Snapshot's to get access to the SnapshotTableInfo objects + # required by StateSync.remove_intervals() + # but at this point we have minimized the list by excluding the ones that are already present in prod + # and also excluding the ones we have already matched earlier while traversing the environment DAGs + remaining_snapshots = state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids) + for remaining_snapshot_id, remaining_snapshot in remaining_snapshots.items(): + snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest( + table_info=remaining_snapshot.table_info, + interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name], + ) + + loaded_snapshots.update(remaining_snapshots) + + # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to + # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record, + # so we only do it if necessary + full_history_restatement_snapshot_ids = [ + # FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self + # however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state, + # is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present) + # So for now, these are not considered + s_id + for s_id, s in snapshot_intervals_to_clear.items() + if s.table_info.full_history_restatement_only + ] + if full_history_restatement_snapshot_ids: + # only load full snapshot records that we havent already loaded + additional_snapshots = state_reader.get_snapshots( + [ + s.snapshot_id + for s in full_history_restatement_snapshot_ids + if s.snapshot_id not in loaded_snapshots + ] + ) + + all_snapshots = loaded_snapshots | additional_snapshots + + for full_snapshot_id in full_history_restatement_snapshot_ids: + full_snapshot = all_snapshots[full_snapshot_id] + intervals_to_clear = snapshot_intervals_to_clear[full_snapshot_id] + + original_start, original_end = intervals_to_clear.interval + + # get_removal_interval() widens intervals if necessary + new_interval = full_snapshot.get_removal_interval( + start=original_start, end=original_end + ) + + intervals_to_clear.interval = new_interval + + return snapshot_intervals_to_clear diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 03b0b64016..9263a08631 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -22,7 +22,7 @@ from sqlmesh.core.console import Console, get_console from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements from sqlmesh.core.macros import RuntimeStage -from sqlmesh.core.snapshot.definition import Interval, to_view_mapping +from sqlmesh.core.snapshot.definition import to_view_mapping from sqlmesh.core.plan import stages from sqlmesh.core.plan.definition import EvaluatablePlan from sqlmesh.core.scheduler import Scheduler @@ -33,17 +33,15 @@ SnapshotIntervals, SnapshotId, SnapshotInfoLike, - SnapshotTableInfo, SnapshotCreationFailedError, - SnapshotNameVersion, ) from sqlmesh.utils import to_snake_case from sqlmesh.core.state_sync import StateSync +from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions from sqlmesh.utils import CorrelationId from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.errors import PlanError, SQLMeshError -from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import now +from sqlmesh.utils.date import now, to_timestamp logger = logging.getLogger(__name__) @@ -298,11 +296,16 @@ def visit_restatement_stage( # # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod snapshot_intervals_to_restate.update( - self._restatement_intervals_across_all_environments( - prod_restatements=plan.restatements, - disable_restatement_models=plan.disabled_restatement_models, - loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, - ) + { + (s.table_info, s.interval) + for s in identify_restatement_intervals_across_snapshot_versions( + state_reader=self.state_sync, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + current_ts=to_timestamp(plan.execution_time or now()), + ).values() + } ) self.state_sync.remove_intervals( @@ -422,97 +425,6 @@ def _demote_snapshots( on_complete=on_complete, ) - def _restatement_intervals_across_all_environments( - self, - prod_restatements: t.Dict[str, Interval], - disable_restatement_models: t.Set[str], - loaded_snapshots: t.Dict[SnapshotId, Snapshot], - ) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]: - """ - Given a map of snapshot names + intervals to restate in prod: - - Look up matching snapshots across all environments (match based on name - regardless of version) - - For each match, also match downstream snapshots while filtering out models that have restatement disabled - - Return all matches mapped to the intervals of the prod snapshot being restated - - The goal here is to produce a list of intervals to invalidate across all environments so that a cadence - run in those environments causes the intervals to be repopulated - """ - if not prod_restatements: - return set() - - prod_name_versions: t.Set[SnapshotNameVersion] = { - s.name_version for s in loaded_snapshots.values() - } - - snapshots_to_restate: t.Dict[SnapshotId, t.Tuple[SnapshotTableInfo, Interval]] = {} - - for env_summary in self.state_sync.get_environments_summary(): - # Fetch the full environment object one at a time to avoid loading all environments into memory at once - env = self.state_sync.get_environment(env_summary.name) - if not env: - logger.warning("Environment %s not found", env_summary.name) - continue - - keyed_snapshots = {s.name: s.table_info for s in env.snapshots} - - # We dont just restate matching snapshots, we also have to restate anything downstream of them - # so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev - env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots}) - - for restatement, intervals in prod_restatements.items(): - if restatement not in keyed_snapshots: - continue - affected_snapshot_names = [ - x - for x in ([restatement] + env_dag.downstream(restatement)) - if x not in disable_restatement_models - ] - snapshots_to_restate.update( - { - keyed_snapshots[a].snapshot_id: (keyed_snapshots[a], intervals) - for a in affected_snapshot_names - # Don't restate a snapshot if it shares the version with a snapshot in prod - if keyed_snapshots[a].name_version not in prod_name_versions - } - ) - - # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to - # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record, - # so we only do it if necessary - full_history_restatement_snapshot_ids = [ - # FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self - # however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state, - # is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present) - # So for now, these are not considered - s_id - for s_id, s in snapshots_to_restate.items() - if s[0].full_history_restatement_only - ] - if full_history_restatement_snapshot_ids: - # only load full snapshot records that we havent already loaded - additional_snapshots = self.state_sync.get_snapshots( - [ - s.snapshot_id - for s in full_history_restatement_snapshot_ids - if s.snapshot_id not in loaded_snapshots - ] - ) - - all_snapshots = loaded_snapshots | additional_snapshots - - for full_snapshot_id in full_history_restatement_snapshot_ids: - full_snapshot = all_snapshots[full_snapshot_id] - _, original_intervals = snapshots_to_restate[full_snapshot_id] - original_start, original_end = original_intervals - - # get_removal_interval() widens intervals if necessary - new_intervals = full_snapshot.get_removal_interval( - start=original_start, end=original_end - ) - snapshots_to_restate[full_snapshot_id] = (full_snapshot.table_info, new_intervals) - - return set(snapshots_to_restate.values()) - def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None: snapshots_intervals: t.List[SnapshotIntervals] = [] for snapshot in snapshots: diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 948882c4dc..c042850c1f 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -4238,6 +4238,180 @@ def test_prod_restatement_plan_missing_model_in_dev( ) +def test_prod_restatement_plan_includes_related_unpromoted_snapshots(tmp_path: Path): + """ + Scenario: + - I have models A <- B in prod + - I have models A <- B <- C in dev + - Both B and C have gone through a few iterations in dev so multiple snapshot versions exist + for them but not all of them are promoted / active + - I restate A in prod + + Outcome: + - Intervals should be cleared for all of the versions of B and C, regardless + of if they are active in any particular environment, in case they ever get made + active + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + (models_dir / "a.sql").write_text(""" + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select 1 as a, now() as ts; + """) + + (models_dir / "b.sql").write_text(""" + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select a, ts from test.a + """) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01")) + ctx = Context(paths=[tmp_path], config=config) + + def _all_snapshots() -> t.Dict[SnapshotId, Snapshot]: + all_snapshot_ids = [ + SnapshotId(name=name, identifier=identifier) + for (name, identifier) in ctx.state_sync.state_sync.engine_adapter.fetchall( # type: ignore + "select name, identifier from sqlmesh._snapshots" + ) + ] + return ctx.state_sync.get_snapshots(all_snapshot_ids) + + # plan + apply prod + ctx.plan(environment="prod", auto_apply=True) + assert len(_all_snapshots()) == 2 + + # create dev with new version of B + (models_dir / "b.sql").write_text(""" + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select a, ts, 'b dev 1' as change from test.a + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + assert len(_all_snapshots()) == 3 + + # update B (new version) and create C + (models_dir / "b.sql").write_text(""" + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select a, ts, 'b dev 2' as change from test.a + """) + + (models_dir / "c.sql").write_text(""" + MODEL ( + name test.c, + kind FULL, + cron '@daily' + ); + + select *, 'c initial' as val from test.b + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + assert len(_all_snapshots()) == 5 + + # update C (new version), create D (unrelated) + (models_dir / "c.sql").write_text(""" + MODEL ( + name test.c, + kind FULL, + cron '@daily' + ); + + select *, 'c updated' as val from test.b + """) + + (models_dir / "d.sql").write_text(""" + MODEL ( + name test.d, + cron '@daily' + ); + + select 1 as unrelated + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + all_snapshots_prior_to_restatement = _all_snapshots() + assert len(all_snapshots_prior_to_restatement) == 7 + + def _snapshot_instances(lst: t.Dict[SnapshotId, Snapshot], name_match: str) -> t.List[Snapshot]: + return [s for s_id, s in lst.items() if name_match in s_id.name] + + # verify initial state + + # 1 instance of A (prod) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"a"')) == 1 + + # 3 instances of B (original in prod + 2 updates in dev) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"b"')) == 3 + + # 2 instances of C (initial + update in dev) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"c"')) == 2 + + # 1 instance of D (initial - dev) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"d"')) == 1 + + # restate A in prod + ctx.plan(environment="prod", restate_models=['"memory"."test"."a"'], auto_apply=True) + + all_snapshots_after_restatement = _all_snapshots() + + # All versions of B and C in dev should have had intervals cleared + # D in dev should not be touched and A + B in prod shoud also not be touched + a = _snapshot_instances(all_snapshots_after_restatement, '"a"') + assert len(a) == 1 + + b = _snapshot_instances(all_snapshots_after_restatement, '"b"') + # the 1 B instance in prod should be populated and 2 in dev (1 active) should be cleared + assert len(b) == 3 + assert len([s for s in b if not s.intervals]) == 2 + + c = _snapshot_instances(all_snapshots_after_restatement, '"c"') + # the 2 instances of C in dev (1 active) should be cleared + assert len(c) == 2 + assert len([s for s in c if not s.intervals]) == 2 + + d = _snapshot_instances(all_snapshots_after_restatement, '"d"') + # D should not be touched since it's in no way downstream of A in prod + assert len(d) == 1 + assert d[0].intervals + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_dev_restatement_of_prod_model(init_and_plan_context: t.Callable): context, plan = init_and_plan_context("examples/sushi") From 924e59e1c06e371b75094b8e13a47683dcab5a1f Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Tue, 9 Sep 2025 22:47:37 +0000 Subject: [PATCH 2/3] PR feedback --- sqlmesh/core/plan/common.py | 22 ++++++++-------------- sqlmesh/core/plan/evaluator.py | 6 ++++-- sqlmesh/core/snapshot/__init__.py | 1 + sqlmesh/core/snapshot/definition.py | 23 ++++++++++++++++++++++- sqlmesh/core/state_sync/base.py | 3 ++- sqlmesh/core/state_sync/cache.py | 3 ++- sqlmesh/core/state_sync/db/facade.py | 3 ++- sqlmesh/core/state_sync/db/interval.py | 12 +++++++----- sqlmesh/core/state_sync/db/snapshot.py | 11 ++++++++--- tests/core/test_snapshot.py | 25 +++++++++++++++++++++++++ 10 files changed, 81 insertions(+), 28 deletions(-) diff --git a/sqlmesh/core/plan/common.py b/sqlmesh/core/plan/common.py index 2a59092f67..aabc45c0b2 100644 --- a/sqlmesh/core/plan/common.py +++ b/sqlmesh/core/plan/common.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from sqlmesh.core.state_sync import StateReader -from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo, SnapshotNameVersion +from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotIdAndVersion, SnapshotNameVersion from sqlmesh.core.snapshot.definition import Interval from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import now_timestamp @@ -41,7 +41,7 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool: @dataclass class SnapshotIntervalClearRequest: # affected snapshot - table_info: SnapshotTableInfo + snapshot: SnapshotIdAndVersion # which interval to clear interval: Interval @@ -53,7 +53,7 @@ class SnapshotIntervalClearRequest: @property def snapshot_id(self) -> SnapshotId: - return self.table_info.snapshot_id + return self.snapshot.snapshot_id @property def sorted_environment_names(self) -> t.List[str]: @@ -122,7 +122,7 @@ def identify_restatement_intervals_across_snapshot_versions( clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id) if not clear_request: clear_request = SnapshotIntervalClearRequest( - table_info=affected_snapshot, interval=interval + snapshot=affected_snapshot.id_and_version, interval=interval ) snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request @@ -164,19 +164,13 @@ def identify_restatement_intervals_across_snapshot_versions( snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end) - # we need to fetch full Snapshot's to get access to the SnapshotTableInfo objects - # required by StateSync.remove_intervals() - # but at this point we have minimized the list by excluding the ones that are already present in prod - # and also excluding the ones we have already matched earlier while traversing the environment DAGs - remaining_snapshots = state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids) - for remaining_snapshot_id, remaining_snapshot in remaining_snapshots.items(): + for remaining_snapshot_id in remaining_snapshot_ids: + remaining_snapshot = all_matching_non_prod_snapshots[remaining_snapshot_id] snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest( - table_info=remaining_snapshot.table_info, + snapshot=remaining_snapshot, interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name], ) - loaded_snapshots.update(remaining_snapshots) - # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record, # so we only do it if necessary @@ -187,7 +181,7 @@ def identify_restatement_intervals_across_snapshot_versions( # So for now, these are not considered s_id for s_id, s in snapshot_intervals_to_clear.items() - if s.table_info.full_history_restatement_only + if s.snapshot.kind_name and s.snapshot.kind_name.full_history_restatement_only ] if full_history_restatement_snapshot_ids: # only load full snapshot records that we havent already loaded diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 9263a08631..79053e018b 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -287,7 +287,9 @@ def visit_audit_only_run_stage( def visit_restatement_stage( self, stage: stages.RestatementStage, plan: EvaluatablePlan ) -> None: - snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()} + snapshot_intervals_to_restate = { + (s.id_and_version, i) for s, i in stage.snapshot_intervals.items() + } # Restating intervals on prod plans should mean that the intervals are cleared across # all environments, not just the version currently in prod @@ -297,7 +299,7 @@ def visit_restatement_stage( # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod snapshot_intervals_to_restate.update( { - (s.table_info, s.interval) + (s.snapshot, s.interval) for s in identify_restatement_intervals_across_snapshot_versions( state_reader=self.state_sync, prod_restatements=plan.restatements, diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index 32842cc4b2..65e5c2a822 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -11,6 +11,7 @@ SnapshotId as SnapshotId, SnapshotIdBatch as SnapshotIdBatch, SnapshotIdLike as SnapshotIdLike, + SnapshotIdAndVersionLike as SnapshotIdAndVersionLike, SnapshotInfoLike as SnapshotInfoLike, SnapshotIntervals as SnapshotIntervals, SnapshotNameVersion as SnapshotNameVersion, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index c124c2098f..7c3534f746 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -587,6 +587,17 @@ def name_version(self) -> SnapshotNameVersion: """Returns the name and version of the snapshot.""" return SnapshotNameVersion(name=self.name, version=self.version) + @property + def id_and_version(self) -> SnapshotIdAndVersion: + return SnapshotIdAndVersion( + name=self.name, + kind_name=self.kind_name, + identifier=self.identifier, + version=self.version, + dev_version=self.dev_version, + fingerprint=self.fingerprint, + ) + class SnapshotIdAndVersion(PydanticModel): """A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table @@ -595,6 +606,7 @@ class SnapshotIdAndVersion(PydanticModel): name: str version: str + kind_name: t.Optional[ModelKindName] = None dev_version_: t.Optional[str] = Field(alias="dev_version") identifier: str fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint") @@ -603,6 +615,10 @@ class SnapshotIdAndVersion(PydanticModel): def snapshot_id(self) -> SnapshotId: return SnapshotId(name=self.name, identifier=self.identifier) + @property + def id_and_version(self) -> SnapshotIdAndVersion: + return self + @property def name_version(self) -> SnapshotNameVersion: return SnapshotNameVersion(name=self.name, version=self.version) @@ -1424,6 +1440,10 @@ def name_version(self) -> SnapshotNameVersion: """Returns the name and version of the snapshot.""" return SnapshotNameVersion(name=self.name, version=self.version) + @property + def id_and_version(self) -> SnapshotIdAndVersion: + return self.table_info.id_and_version + @property def disable_restatement(self) -> bool: """Is restatement disabled for the node""" @@ -1494,7 +1514,8 @@ class SnapshotTableCleanupTask(PydanticModel): dev_table_only: bool -SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot] +SnapshotIdLike = t.Union[SnapshotId, SnapshotIdAndVersion, SnapshotTableInfo, Snapshot] +SnapshotIdAndVersionLike = t.Union[SnapshotIdAndVersion, SnapshotTableInfo, Snapshot] SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot] SnapshotNameVersionLike = t.Union[ SnapshotNameVersion, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 4219472cb6..450d6f7408 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -19,6 +19,7 @@ Snapshot, SnapshotId, SnapshotIdLike, + SnapshotIdAndVersionLike, SnapshotInfoLike, SnapshotTableCleanupTask, SnapshotTableInfo, @@ -390,7 +391,7 @@ def remove_state(self, including_backup: bool = False) -> None: @abc.abstractmethod def remove_intervals( self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], remove_shared_versions: bool = False, ) -> None: """Remove an interval from a list of snapshots and sync it to the store. diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index 8aa5054e13..3de4e7bf51 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -7,6 +7,7 @@ Snapshot, SnapshotId, SnapshotIdLike, + SnapshotIdAndVersionLike, SnapshotInfoLike, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals @@ -128,7 +129,7 @@ def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotInterv def remove_intervals( self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], remove_shared_versions: bool = False, ) -> None: for s, _ in snapshot_intervals: diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 93c4b87e9e..29fc9f1740 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -31,6 +31,7 @@ SnapshotIdAndVersion, SnapshotId, SnapshotIdLike, + SnapshotIdAndVersionLike, SnapshotInfoLike, SnapshotIntervals, SnapshotNameVersion, @@ -407,7 +408,7 @@ def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotInterv @transactional() def remove_intervals( self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], remove_shared_versions: bool = False, ) -> None: self.interval_state.remove_intervals(snapshot_intervals, remove_shared_versions) diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py index 75f475b75b..b15ad2d57b 100644 --- a/sqlmesh/core/state_sync/db/interval.py +++ b/sqlmesh/core/state_sync/db/interval.py @@ -15,10 +15,10 @@ from sqlmesh.core.snapshot import ( SnapshotIntervals, SnapshotIdLike, + SnapshotIdAndVersionLike, SnapshotNameVersionLike, SnapshotTableCleanupTask, SnapshotNameVersion, - SnapshotInfoLike, Snapshot, ) from sqlmesh.core.snapshot.definition import Interval @@ -68,11 +68,11 @@ def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotInterv def remove_intervals( self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], remove_shared_versions: bool = False, ) -> None: intervals_to_remove: t.Sequence[ - t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval] + t.Tuple[t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], Interval] ] = snapshot_intervals if remove_shared_versions: name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} @@ -431,7 +431,9 @@ def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask] def _intervals_to_df( - snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]], + snapshot_intervals: t.Sequence[ + t.Tuple[t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], Interval] + ], is_dev: bool, is_removed: bool, ) -> pd.DataFrame: @@ -451,7 +453,7 @@ def _intervals_to_df( def _interval_to_df( - snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals], + snapshot: t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], start_ts: int, end_ts: int, is_dev: bool = False, diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 1904e51c55..4a8b2c44c5 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -337,6 +337,7 @@ def get_snapshots_by_names( name=name, identifier=identifier, version=version, + kind_name=kind_name or None, dev_version=dev_version, fingerprint=fingerprint, ) @@ -344,9 +345,11 @@ def get_snapshots_by_names( snapshot_names=snapshot_names, batch_size=self.SNAPSHOT_BATCH_SIZE, ) - for name, identifier, version, dev_version, fingerprint in fetchall( + for name, identifier, version, kind_name, dev_version, fingerprint in fetchall( self.engine_adapter, - exp.select("name", "identifier", "version", "dev_version", "fingerprint") + exp.select( + "name", "identifier", "version", "kind_name", "dev_version", "fingerprint" + ) .from_(self.snapshots_table) .where(where) .and_(unexpired_expr), @@ -661,6 +664,7 @@ def _get_snapshots_with_same_version( "name", "identifier", "version", + "kind_name", "dev_version", "fingerprint", ) @@ -677,10 +681,11 @@ def _get_snapshots_with_same_version( name=name, identifier=identifier, version=version, + kind_name=kind_name or None, dev_version=dev_version, fingerprint=SnapshotFingerprint.parse_raw(fingerprint), ) - for name, identifier, version, dev_version, fingerprint in snapshot_rows + for name, identifier, version, kind_name, dev_version, fingerprint in snapshot_rows ] diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index eff3ad2b60..4aa8e4bc0e 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3567,3 +3567,28 @@ def test_snapshot_id_and_version_fingerprint_lazy_init(): assert isinstance(snapshot.fingerprint_, SnapshotFingerprint) assert snapshot.fingerprint == fingerprint + + +def test_snapshot_id_and_version_optional_kind_name(): + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + dev_version=None, + fingerprint="", + ) + + assert snapshot.kind_name is None + + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + kind_name="INCREMENTAL_UNMANAGED", + dev_version=None, + fingerprint="", + ) + + assert snapshot.kind_name + assert snapshot.kind_name.is_incremental_unmanaged + assert snapshot.kind_name.full_history_restatement_only From c7e3fa56234fdbabfaffedd433ec1a5046b1e063 Mon Sep 17 00:00:00 2001 From: Erin Drummond Date: Wed, 10 Sep 2025 20:50:50 +0000 Subject: [PATCH 3/3] Implement ModelKindMixin on SnapshotIdAndVersion --- sqlmesh/core/plan/common.py | 2 +- sqlmesh/core/snapshot/definition.py | 8 ++++++-- tests/core/test_snapshot.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/sqlmesh/core/plan/common.py b/sqlmesh/core/plan/common.py index aabc45c0b2..4ae8a3112c 100644 --- a/sqlmesh/core/plan/common.py +++ b/sqlmesh/core/plan/common.py @@ -181,7 +181,7 @@ def identify_restatement_intervals_across_snapshot_versions( # So for now, these are not considered s_id for s_id, s in snapshot_intervals_to_clear.items() - if s.snapshot.kind_name and s.snapshot.kind_name.full_history_restatement_only + if s.snapshot.full_history_restatement_only ] if full_history_restatement_snapshot_ids: # only load full snapshot records that we havent already loaded diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 7c3534f746..f05c2fb7ab 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -599,14 +599,14 @@ def id_and_version(self) -> SnapshotIdAndVersion: ) -class SnapshotIdAndVersion(PydanticModel): +class SnapshotIdAndVersion(PydanticModel, ModelKindMixin): """A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table without the overhead of parsing the full snapshot payload and fetching intervals. """ name: str version: str - kind_name: t.Optional[ModelKindName] = None + kind_name_: t.Optional[ModelKindName] = Field(default=None, alias="kind_name") dev_version_: t.Optional[str] = Field(alias="dev_version") identifier: str fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint") @@ -634,6 +634,10 @@ def fingerprint(self) -> SnapshotFingerprint: def dev_version(self) -> str: return self.dev_version_ or self.fingerprint.to_version() + @property + def model_kind_name(self) -> t.Optional[ModelKindName]: + return self.kind_name_ + class Snapshot(PydanticModel, SnapshotInfoMixin): """A snapshot represents a node at a certain point in time. diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 4aa8e4bc0e..c769991b86 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3578,7 +3578,7 @@ def test_snapshot_id_and_version_optional_kind_name(): fingerprint="", ) - assert snapshot.kind_name is None + assert snapshot.model_kind_name is None snapshot = SnapshotIdAndVersion( name="a", @@ -3589,6 +3589,6 @@ def test_snapshot_id_and_version_optional_kind_name(): fingerprint="", ) - assert snapshot.kind_name - assert snapshot.kind_name.is_incremental_unmanaged - assert snapshot.kind_name.full_history_restatement_only + assert snapshot.model_kind_name + assert snapshot.is_incremental_unmanaged + assert snapshot.full_history_restatement_only