From 050d94e1c2247bdbbeab2dc09e72bae6e991913d Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Wed, 13 Aug 2025 09:07:21 -0700 Subject: [PATCH 1/3] Chore!: Optimize snapshot unpausing --- sqlmesh/core/state_sync/db/facade.py | 2 +- sqlmesh/core/state_sync/db/migrator.py | 2 +- sqlmesh/core/state_sync/db/snapshot.py | 145 +++++++----------- sqlmesh/core/state_sync/db/utils.py | 15 ++ .../v0090_add_forward_only_column.py | 100 ++++++++++++ tests/core/state_sync/test_state_sync.py | 122 +-------------- 6 files changed, 177 insertions(+), 209 deletions(-) create mode 100644 sqlmesh/migrations/v0090_add_forward_only_column.py diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 779add1cca..858b1aa072 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -256,7 +256,7 @@ def finalize(self, environment: Environment) -> None: def unpause_snapshots( self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike ) -> None: - self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state) + self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt) def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: self.environment_state.invalidate_environment(name, protect_prod) diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py index 405c0ea667..ca89668763 100644 --- a/sqlmesh/core/state_sync/db/migrator.py +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -396,7 +396,7 @@ def _migrate_environment_rows( if updated_prod_environment: try: self.snapshot_state.unpause_snapshots( - updated_prod_environment.snapshots, now_timestamp(), self.interval_state + updated_prod_environment.snapshots, now_timestamp() ) except Exception: logger.warning("Failed to unpause migrated snapshots", exc_info=True) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 30e0de00f2..4ea4a837fd 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -10,6 +10,7 @@ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.state_sync.db.utils import ( + snapshot_name_filter, snapshot_name_version_filter, snapshot_id_filter, fetchone, @@ -32,15 +33,13 @@ SnapshotChangeCategory, ) from sqlmesh.utils.migration import index_text_type, blob_text_type -from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp +from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp from sqlmesh.utils.pydantic import PydanticModel from sqlmesh.utils import unique if t.TYPE_CHECKING: import pandas as pd - from sqlmesh.core.state_sync.db.interval import IntervalState - logger = logging.getLogger(__name__) @@ -70,6 +69,7 @@ def __init__( "unpaused_ts": exp.DataType.build("bigint"), "ttl_ms": exp.DataType.build("bigint"), "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), } self._auto_restatement_columns_to_types = { @@ -112,84 +112,48 @@ def unpause_snapshots( self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike, - interval_state: IntervalState, ) -> None: - """Unpauses given snapshots while pausing all other snapshots that share the same version. - - Args: - snapshots: The snapshots to unpause. - unpaused_dt: The timestamp to unpause the snapshots at. - interval_state: The interval state to use to remove intervals when needed. - """ - current_ts = now() - - target_snapshot_ids = {s.snapshot_id for s in snapshots} - same_version_snapshots = self._get_snapshots_with_same_version( - snapshots, lock_for_update=True - ) - target_snapshots_by_version = { - (s.name, s.version): s - for s in same_version_snapshots - if s.snapshot_id in target_snapshot_ids - } + unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[str]] = defaultdict(list) - unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list) - paused_snapshots: t.List[SnapshotId] = [] - unrestorable_snapshots: t.List[SnapshotId] = [] - - for snapshot in same_version_snapshots: - is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids - if is_target_snapshot and not snapshot.unpaused_ts: - logger.info("Unpausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(unpaused_dt) - assert snapshot.unpaused_ts is not None - unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id) - elif not is_target_snapshot: - target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)] - if ( - target_snapshot.normalized_effective_from_ts - and not target_snapshot.disable_restatement - ): - # Making sure that there are no overlapping intervals. - effective_from_ts = target_snapshot.normalized_effective_from_ts - logger.info( - "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s", - target_snapshot.effective_from, - snapshot.snapshot_id, - target_snapshot.snapshot_id, - ) - full_snapshot = snapshot.full_snapshot - interval_state.remove_intervals( - [ - ( - full_snapshot, - full_snapshot.get_removal_interval(effective_from_ts, current_ts), - ) - ] - ) - - if snapshot.unpaused_ts: - logger.info("Pausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(None) - paused_snapshots.append(snapshot.snapshot_id) + for snapshot in snapshots: + # We need to mark all other snapshots that have opposite forward only status as unrestorable + unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append( + snapshot.name + ) - if not snapshot.unrestorable and ( - (target_snapshot.is_forward_only and not snapshot.is_forward_only) - or (snapshot.is_forward_only and not target_snapshot.is_forward_only) - ): - logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id) - snapshot.unrestorable = True - unrestorable_snapshots.append(snapshot.snapshot_id) + updated_ts = now_timestamp() + unpaused_ts = to_timestamp(unpaused_dt) - if unpaused_snapshots: - for unpaused_ts, snapshot_ids in unpaused_snapshots.items(): - self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts) + # Pause all snapshots with target names first + for where in snapshot_name_filter( + [s.name for s in snapshots], + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.update_table( + self.snapshots_table, + {"unpaused_ts": None, "updated_ts": updated_ts}, + where=where, + ) - if paused_snapshots: - self._update_snapshots(paused_snapshots, unpaused_ts=None) + # Now unpause the target snapshots + self._update_snapshots( + [s.snapshot_id for s in snapshots], + unpaused_ts=unpaused_ts, + updated_ts=updated_ts, + ) - if unrestorable_snapshots: - self._update_snapshots(unrestorable_snapshots, unrestorable=True) + # Mark unrestorable snapshots + for forward_only, snapshot_names in unrestorable_snapshots_by_forward_only.items(): + forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only)) + for where in snapshot_name_filter( + snapshot_names, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.update_table( + self.snapshots_table, + {"unrestorable": True, "updated_ts": updated_ts}, + where=forward_only_exp.and_(where), + ) def get_expired_snapshots( self, @@ -414,7 +378,8 @@ def _update_snapshots( **kwargs: t.Any, ) -> None: properties = kwargs - properties["updated_ts"] = now_timestamp() + if "updated_ts" not in properties: + properties["updated_ts"] = now_timestamp() for where in snapshot_id_filter( self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE @@ -466,6 +431,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: updated_ts, unpaused_ts, unrestorable, + forward_only, next_auto_restatement_ts, ) in fetchall(self.engine_adapter, query): snapshot = parse_snapshot( @@ -473,6 +439,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: updated_ts=updated_ts, unpaused_ts=unpaused_ts, unrestorable=unrestorable, + forward_only=forward_only, next_auto_restatement_ts=next_auto_restatement_ts, ) snapshot_id = snapshot.snapshot_id @@ -502,6 +469,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: "updated_ts", "unpaused_ts", "unrestorable", + "forward_only", "next_auto_restatement_ts", ) .from_(exp.to_table(self.snapshots_table).as_("snapshots")) @@ -528,6 +496,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: updated_ts, unpaused_ts, unrestorable, + forward_only, next_auto_restatement_ts, ) in fetchall(self.engine_adapter, query): snapshot_id = SnapshotId(name=name, identifier=identifier) @@ -535,6 +504,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: snapshot.updated_ts = updated_ts snapshot.unpaused_ts = unpaused_ts snapshot.unrestorable = unrestorable + snapshot.forward_only = forward_only snapshot.next_auto_restatement_ts = next_auto_restatement_ts cached_snapshots_in_state.add(snapshot_id) @@ -568,6 +538,7 @@ def _get_snapshots_expressions( "snapshots.updated_ts", "snapshots.unpaused_ts", "snapshots.unrestorable", + "snapshots.forward_only", "auto_restatements.next_auto_restatement_ts", ) .from_(exp.to_table(self.snapshots_table).as_("snapshots")) @@ -623,6 +594,7 @@ def _get_snapshots_with_same_version( "updated_ts", "unpaused_ts", "unrestorable", + "forward_only", ) .from_(exp.to_table(self.snapshots_table).as_("snapshots")) .where(where) @@ -640,9 +612,10 @@ def _get_snapshots_with_same_version( updated_ts=updated_ts, unpaused_ts=unpaused_ts, unrestorable=unrestorable, + forward_only=forward_only, snapshot=snapshot, ) - for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows + for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows ] @@ -651,6 +624,7 @@ def parse_snapshot( updated_ts: int, unpaused_ts: t.Optional[int], unrestorable: bool, + forward_only: bool, next_auto_restatement_ts: t.Optional[int], ) -> Snapshot: return Snapshot( @@ -659,6 +633,7 @@ def parse_snapshot( "updated_ts": updated_ts, "unpaused_ts": unpaused_ts, "unrestorable": unrestorable, + "forward_only": forward_only, "next_auto_restatement_ts": next_auto_restatement_ts, } ) @@ -673,6 +648,7 @@ def _snapshot_to_json(snapshot: Snapshot) -> str: "updated_ts", "unpaused_ts", "unrestorable", + "forward_only", "next_auto_restatement_ts", } ) @@ -693,6 +669,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: "unpaused_ts": snapshot.unpaused_ts, "ttl_ms": snapshot.ttl_ms, "unrestorable": snapshot.unrestorable, + "forward_only": snapshot.forward_only, } for snapshot in snapshots ] @@ -762,19 +739,10 @@ def full_snapshot(self) -> Snapshot: "updated_ts": self.updated_ts, "unpaused_ts": self.unpaused_ts, "unrestorable": self.unrestorable, + "forward_only": self.forward_only, } ) - def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None: - """Sets the timestamp for when this snapshot was unpaused. - - Args: - unpaused_dt: The datetime object of when this snapshot was unpaused. - """ - self.unpaused_ts = ( - to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None - ) - @classmethod def from_snapshot_record( cls, @@ -785,6 +753,7 @@ def from_snapshot_record( updated_ts: int, unpaused_ts: t.Optional[int], unrestorable: bool, + forward_only: bool, snapshot: str, ) -> SharedVersionSnapshot: raw_snapshot = json.loads(snapshot) @@ -803,5 +772,5 @@ def from_snapshot_record( disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False), effective_from=raw_snapshot.get("effective_from"), raw_snapshot=raw_snapshot, - forward_only=raw_snapshot.get("forward_only", False), + forward_only=forward_only, ) diff --git a/sqlmesh/core/state_sync/db/utils.py b/sqlmesh/core/state_sync/db/utils.py index e5ffda6486..87c259f5d6 100644 --- a/sqlmesh/core/state_sync/db/utils.py +++ b/sqlmesh/core/state_sync/db/utils.py @@ -22,6 +22,21 @@ T = t.TypeVar("T") +def snapshot_name_filter( + snapshot_names: t.Iterable[str], + batch_size: int, + alias: t.Optional[str] = None, +) -> t.Iterator[exp.Condition]: + names = sorted(snapshot_names) + + if not names: + yield exp.false() + else: + batches = create_batches(names, batch_size=batch_size) + for names in batches: + yield exp.column("name", table=alias).isin(*names) + + def snapshot_id_filter( engine_adapter: EngineAdapter, snapshot_ids: t.Iterable[SnapshotIdLike], diff --git a/sqlmesh/migrations/v0090_add_forward_only_column.py b/sqlmesh/migrations/v0090_add_forward_only_column.py new file mode 100644 index 0000000000..32efc14eed --- /dev/null +++ b/sqlmesh/migrations/v0090_add_forward_only_column.py @@ -0,0 +1,100 @@ +"""Add forward_only column to the snapshots table.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + import pandas as pd + + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + alter_table_exp = exp.Alter( + this=exp.to_table(snapshots_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("forward_only"), + kind=exp.DataType.build("boolean"), + ) + ], + ) + engine_adapter.execute(alter_table_exp) + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + forward_only, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + "forward_only", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + + forward_only = parsed_snapshot.get("forward_only") + if forward_only is None: + forward_only = parsed_snapshot.get("change_category") == 3 + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + "forward_only": forward_only, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), + }, + ) diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index d8e96a1f35..d61907a5aa 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -1888,50 +1888,6 @@ def test_unpause_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t. assert not actual_snapshots[new_snapshot.snapshot_id].unrestorable -def test_unpause_snapshots_hourly(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 1, ds"), - cron="@hourly", - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.version = "a" - - assert not snapshot.unpaused_ts - state_sync.push_snapshots([snapshot]) - - # Unpaused timestamp not aligned with cron - unpaused_dt = "2022-01-01 01:22:33" - state_sync.unpause_snapshots([snapshot], unpaused_dt) - - actual_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] - assert actual_snapshot.unpaused_ts - assert actual_snapshot.unpaused_ts == to_timestamp("2022-01-01 01:00:00") - - new_snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 2, ds"), - cron="@daily", - interval_unit="hour", - ) - ) - new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - new_snapshot.version = "a" - - assert not new_snapshot.unpaused_ts - state_sync.push_snapshots([new_snapshot]) - state_sync.unpause_snapshots([new_snapshot], unpaused_dt) - - actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot]) - assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts - assert actual_snapshots[new_snapshot.snapshot_id].unpaused_ts == to_timestamp( - "2022-01-01 01:00:00" - ) - - def test_unrestorable_snapshot(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): snapshot = make_snapshot( SqlModel( @@ -2037,81 +1993,6 @@ def test_unrestorable_snapshot_target_not_forward_only( assert not actual_snapshots[updated_snapshot.snapshot_id].unrestorable -def test_unpause_snapshots_remove_intervals( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 1, ds"), - cron="@daily", - ), - version="a", - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.version = "a" - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2023-01-01", "2023-01-05") - - new_snapshot = make_snapshot( - SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily"), - version="a", - ) - new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - new_snapshot.version = "a" - new_snapshot.effective_from = "2023-01-03" - state_sync.push_snapshots([new_snapshot]) - state_sync.add_interval(snapshot, "2023-01-06", "2023-01-06") - state_sync.unpause_snapshots([new_snapshot], "2023-01-06") - - actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot]) - assert actual_snapshots[new_snapshot.snapshot_id].intervals == [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), - ] - assert actual_snapshots[snapshot.snapshot_id].intervals == [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), - ] - - -def test_unpause_snapshots_remove_intervals_disabled_restatement( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - kind = dict(name="INCREMENTAL_BY_TIME_RANGE", time_column="ds", disable_restatement=True) - snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 1, ds"), - cron="@daily", - kind=kind, - ), - version="a", - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.version = "a" - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2023-01-01", "2023-01-05") - - new_snapshot = make_snapshot( - SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily", kind=kind), - version="a", - ) - new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - new_snapshot.version = "a" - new_snapshot.effective_from = "2023-01-03" - state_sync.push_snapshots([new_snapshot]) - state_sync.add_interval(snapshot, "2023-01-06", "2023-01-06") - state_sync.unpause_snapshots([new_snapshot], "2023-01-06") - - actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot]) - assert actual_snapshots[new_snapshot.snapshot_id].intervals == [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), - ] - # The intervals shouldn't have been removed because restatement is disabled - assert actual_snapshots[snapshot.snapshot_id].intervals == [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), - ] - - def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: from sqlmesh import __version__ as SQLMESH_VERSION @@ -2999,6 +2880,7 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot): 1, 1, False, + False, None, ], [ @@ -3011,6 +2893,7 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot): 1, 1, False, + False, None, ], ], @@ -3025,6 +2908,7 @@ def test_snapshot_batching(state_sync, mocker, make_snapshot): 1, 1, False, + False, None, ], ], From a1a285c94afc75ad32301973242b399074c1a12f Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Wed, 13 Aug 2025 09:47:15 -0700 Subject: [PATCH 2/3] unrestorable by (name, version) --- sqlmesh/core/state_sync/db/snapshot.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 4ea4a837fd..739c8ea1e2 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -113,12 +113,14 @@ def unpause_snapshots( snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike, ) -> None: - unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[str]] = defaultdict(list) + unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[SnapshotNameVersion]] = ( + defaultdict(list) + ) for snapshot in snapshots: # We need to mark all other snapshots that have opposite forward only status as unrestorable unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append( - snapshot.name + snapshot.name_version ) updated_ts = now_timestamp() @@ -143,11 +145,13 @@ def unpause_snapshots( ) # Mark unrestorable snapshots - for forward_only, snapshot_names in unrestorable_snapshots_by_forward_only.items(): + for forward_only, snapshot_name_versions in unrestorable_snapshots_by_forward_only.items(): forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only)) - for where in snapshot_name_filter( - snapshot_names, + for where in snapshot_name_version_filter( + self.engine_adapter, + snapshot_name_versions, batch_size=self.SNAPSHOT_BATCH_SIZE, + alias=None, ): self.engine_adapter.update_table( self.snapshots_table, From b7d67ec75326dc1d29b64b3effb762d1575db380 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Wed, 13 Aug 2025 10:13:59 -0700 Subject: [PATCH 3/3] improve comment --- sqlmesh/core/state_sync/db/snapshot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 739c8ea1e2..3745a27bb3 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -118,7 +118,7 @@ def unpause_snapshots( ) for snapshot in snapshots: - # We need to mark all other snapshots that have opposite forward only status as unrestorable + # We need to mark all other snapshots that have forward-only opposite to the target snapshot as unrestorable unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append( snapshot.name_version )