Skip to content

Commit 5aacaa8

Browse files
committed
Revert "Chore!: Optimize snapshot unpausing"
This reverts commit bed9f2e.
1 parent bed9f2e commit 5aacaa8

File tree

6 files changed

+209
-177
lines changed

6 files changed

+209
-177
lines changed

sqlmesh/core/state_sync/db/facade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def finalize(self, environment: Environment) -> None:
256256
def unpause_snapshots(
257257
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
258258
) -> None:
259-
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt)
259+
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state)
260260

261261
def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
262262
self.environment_state.invalidate_environment(name, protect_prod)

sqlmesh/core/state_sync/db/migrator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def _migrate_environment_rows(
396396
if updated_prod_environment:
397397
try:
398398
self.snapshot_state.unpause_snapshots(
399-
updated_prod_environment.snapshots, now_timestamp()
399+
updated_prod_environment.snapshots, now_timestamp(), self.interval_state
400400
)
401401
except Exception:
402402
logger.warning("Failed to unpause migrated snapshots", exc_info=True)

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 88 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from sqlmesh.core.engine_adapter import EngineAdapter
1212
from sqlmesh.core.state_sync.db.utils import (
13-
snapshot_name_filter,
1413
snapshot_name_version_filter,
1514
snapshot_id_filter,
1615
fetchone,
@@ -33,13 +32,15 @@
3332
SnapshotChangeCategory,
3433
)
3534
from sqlmesh.utils.migration import index_text_type, blob_text_type
36-
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
35+
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
3736
from sqlmesh.utils.pydantic import PydanticModel
3837
from sqlmesh.utils import unique
3938

4039
if t.TYPE_CHECKING:
4140
import pandas as pd
4241

42+
from sqlmesh.core.state_sync.db.interval import IntervalState
43+
4344

4445
logger = logging.getLogger(__name__)
4546

@@ -69,7 +70,6 @@ def __init__(
6970
"unpaused_ts": exp.DataType.build("bigint"),
7071
"ttl_ms": exp.DataType.build("bigint"),
7172
"unrestorable": exp.DataType.build("boolean"),
72-
"forward_only": exp.DataType.build("boolean"),
7373
}
7474

7575
self._auto_restatement_columns_to_types = {
@@ -112,48 +112,84 @@ def unpause_snapshots(
112112
self,
113113
snapshots: t.Collection[SnapshotInfoLike],
114114
unpaused_dt: TimeLike,
115+
interval_state: IntervalState,
115116
) -> None:
116-
unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[str]] = defaultdict(list)
117+
"""Unpauses given snapshots while pausing all other snapshots that share the same version.
117118
118-
for snapshot in snapshots:
119-
# We need to mark all other snapshots that have opposite forward only status as unrestorable
120-
unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append(
121-
snapshot.name
122-
)
119+
Args:
120+
snapshots: The snapshots to unpause.
121+
unpaused_dt: The timestamp to unpause the snapshots at.
122+
interval_state: The interval state to use to remove intervals when needed.
123+
"""
124+
current_ts = now()
123125

124-
updated_ts = now_timestamp()
125-
unpaused_ts = to_timestamp(unpaused_dt)
126+
target_snapshot_ids = {s.snapshot_id for s in snapshots}
127+
same_version_snapshots = self._get_snapshots_with_same_version(
128+
snapshots, lock_for_update=True
129+
)
130+
target_snapshots_by_version = {
131+
(s.name, s.version): s
132+
for s in same_version_snapshots
133+
if s.snapshot_id in target_snapshot_ids
134+
}
126135

127-
# Pause all snapshots with target names first
128-
for where in snapshot_name_filter(
129-
[s.name for s in snapshots],
130-
batch_size=self.SNAPSHOT_BATCH_SIZE,
131-
):
132-
self.engine_adapter.update_table(
133-
self.snapshots_table,
134-
{"unpaused_ts": None, "updated_ts": updated_ts},
135-
where=where,
136-
)
136+
unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list)
137+
paused_snapshots: t.List[SnapshotId] = []
138+
unrestorable_snapshots: t.List[SnapshotId] = []
139+
140+
for snapshot in same_version_snapshots:
141+
is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids
142+
if is_target_snapshot and not snapshot.unpaused_ts:
143+
logger.info("Unpausing snapshot %s", snapshot.snapshot_id)
144+
snapshot.set_unpaused_ts(unpaused_dt)
145+
assert snapshot.unpaused_ts is not None
146+
unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id)
147+
elif not is_target_snapshot:
148+
target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)]
149+
if (
150+
target_snapshot.normalized_effective_from_ts
151+
and not target_snapshot.disable_restatement
152+
):
153+
# Making sure that there are no overlapping intervals.
154+
effective_from_ts = target_snapshot.normalized_effective_from_ts
155+
logger.info(
156+
"Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s",
157+
target_snapshot.effective_from,
158+
snapshot.snapshot_id,
159+
target_snapshot.snapshot_id,
160+
)
161+
full_snapshot = snapshot.full_snapshot
162+
interval_state.remove_intervals(
163+
[
164+
(
165+
full_snapshot,
166+
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
167+
)
168+
]
169+
)
137170

138-
# Now unpause the target snapshots
139-
self._update_snapshots(
140-
[s.snapshot_id for s in snapshots],
141-
unpaused_ts=unpaused_ts,
142-
updated_ts=updated_ts,
143-
)
171+
if snapshot.unpaused_ts:
172+
logger.info("Pausing snapshot %s", snapshot.snapshot_id)
173+
snapshot.set_unpaused_ts(None)
174+
paused_snapshots.append(snapshot.snapshot_id)
144175

145-
# Mark unrestorable snapshots
146-
for forward_only, snapshot_names in unrestorable_snapshots_by_forward_only.items():
147-
forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only))
148-
for where in snapshot_name_filter(
149-
snapshot_names,
150-
batch_size=self.SNAPSHOT_BATCH_SIZE,
151-
):
152-
self.engine_adapter.update_table(
153-
self.snapshots_table,
154-
{"unrestorable": True, "updated_ts": updated_ts},
155-
where=forward_only_exp.and_(where),
156-
)
176+
if not snapshot.unrestorable and (
177+
(target_snapshot.is_forward_only and not snapshot.is_forward_only)
178+
or (snapshot.is_forward_only and not target_snapshot.is_forward_only)
179+
):
180+
logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id)
181+
snapshot.unrestorable = True
182+
unrestorable_snapshots.append(snapshot.snapshot_id)
183+
184+
if unpaused_snapshots:
185+
for unpaused_ts, snapshot_ids in unpaused_snapshots.items():
186+
self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts)
187+
188+
if paused_snapshots:
189+
self._update_snapshots(paused_snapshots, unpaused_ts=None)
190+
191+
if unrestorable_snapshots:
192+
self._update_snapshots(unrestorable_snapshots, unrestorable=True)
157193

158194
def get_expired_snapshots(
159195
self,
@@ -378,8 +414,7 @@ def _update_snapshots(
378414
**kwargs: t.Any,
379415
) -> None:
380416
properties = kwargs
381-
if "updated_ts" not in properties:
382-
properties["updated_ts"] = now_timestamp()
417+
properties["updated_ts"] = now_timestamp()
383418

384419
for where in snapshot_id_filter(
385420
self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE
@@ -431,15 +466,13 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
431466
updated_ts,
432467
unpaused_ts,
433468
unrestorable,
434-
forward_only,
435469
next_auto_restatement_ts,
436470
) in fetchall(self.engine_adapter, query):
437471
snapshot = parse_snapshot(
438472
serialized_snapshot=serialized_snapshot,
439473
updated_ts=updated_ts,
440474
unpaused_ts=unpaused_ts,
441475
unrestorable=unrestorable,
442-
forward_only=forward_only,
443476
next_auto_restatement_ts=next_auto_restatement_ts,
444477
)
445478
snapshot_id = snapshot.snapshot_id
@@ -469,7 +502,6 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
469502
"updated_ts",
470503
"unpaused_ts",
471504
"unrestorable",
472-
"forward_only",
473505
"next_auto_restatement_ts",
474506
)
475507
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
@@ -496,15 +528,13 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
496528
updated_ts,
497529
unpaused_ts,
498530
unrestorable,
499-
forward_only,
500531
next_auto_restatement_ts,
501532
) in fetchall(self.engine_adapter, query):
502533
snapshot_id = SnapshotId(name=name, identifier=identifier)
503534
snapshot = snapshots[snapshot_id]
504535
snapshot.updated_ts = updated_ts
505536
snapshot.unpaused_ts = unpaused_ts
506537
snapshot.unrestorable = unrestorable
507-
snapshot.forward_only = forward_only
508538
snapshot.next_auto_restatement_ts = next_auto_restatement_ts
509539
cached_snapshots_in_state.add(snapshot_id)
510540

@@ -538,7 +568,6 @@ def _get_snapshots_expressions(
538568
"snapshots.updated_ts",
539569
"snapshots.unpaused_ts",
540570
"snapshots.unrestorable",
541-
"snapshots.forward_only",
542571
"auto_restatements.next_auto_restatement_ts",
543572
)
544573
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
@@ -594,7 +623,6 @@ def _get_snapshots_with_same_version(
594623
"updated_ts",
595624
"unpaused_ts",
596625
"unrestorable",
597-
"forward_only",
598626
)
599627
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
600628
.where(where)
@@ -612,10 +640,9 @@ def _get_snapshots_with_same_version(
612640
updated_ts=updated_ts,
613641
unpaused_ts=unpaused_ts,
614642
unrestorable=unrestorable,
615-
forward_only=forward_only,
616643
snapshot=snapshot,
617644
)
618-
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
645+
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows
619646
]
620647

621648

@@ -624,7 +651,6 @@ def parse_snapshot(
624651
updated_ts: int,
625652
unpaused_ts: t.Optional[int],
626653
unrestorable: bool,
627-
forward_only: bool,
628654
next_auto_restatement_ts: t.Optional[int],
629655
) -> Snapshot:
630656
return Snapshot(
@@ -633,7 +659,6 @@ def parse_snapshot(
633659
"updated_ts": updated_ts,
634660
"unpaused_ts": unpaused_ts,
635661
"unrestorable": unrestorable,
636-
"forward_only": forward_only,
637662
"next_auto_restatement_ts": next_auto_restatement_ts,
638663
}
639664
)
@@ -648,7 +673,6 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
648673
"updated_ts",
649674
"unpaused_ts",
650675
"unrestorable",
651-
"forward_only",
652676
"next_auto_restatement_ts",
653677
}
654678
)
@@ -669,7 +693,6 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
669693
"unpaused_ts": snapshot.unpaused_ts,
670694
"ttl_ms": snapshot.ttl_ms,
671695
"unrestorable": snapshot.unrestorable,
672-
"forward_only": snapshot.forward_only,
673696
}
674697
for snapshot in snapshots
675698
]
@@ -739,10 +762,19 @@ def full_snapshot(self) -> Snapshot:
739762
"updated_ts": self.updated_ts,
740763
"unpaused_ts": self.unpaused_ts,
741764
"unrestorable": self.unrestorable,
742-
"forward_only": self.forward_only,
743765
}
744766
)
745767

768+
def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
769+
"""Sets the timestamp for when this snapshot was unpaused.
770+
771+
Args:
772+
unpaused_dt: The datetime object of when this snapshot was unpaused.
773+
"""
774+
self.unpaused_ts = (
775+
to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None
776+
)
777+
746778
@classmethod
747779
def from_snapshot_record(
748780
cls,
@@ -753,7 +785,6 @@ def from_snapshot_record(
753785
updated_ts: int,
754786
unpaused_ts: t.Optional[int],
755787
unrestorable: bool,
756-
forward_only: bool,
757788
snapshot: str,
758789
) -> SharedVersionSnapshot:
759790
raw_snapshot = json.loads(snapshot)
@@ -772,5 +803,5 @@ def from_snapshot_record(
772803
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
773804
effective_from=raw_snapshot.get("effective_from"),
774805
raw_snapshot=raw_snapshot,
775-
forward_only=forward_only,
806+
forward_only=raw_snapshot.get("forward_only", False),
776807
)

sqlmesh/core/state_sync/db/utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,6 @@
2222
T = t.TypeVar("T")
2323

2424

25-
def snapshot_name_filter(
26-
snapshot_names: t.Iterable[str],
27-
batch_size: int,
28-
alias: t.Optional[str] = None,
29-
) -> t.Iterator[exp.Condition]:
30-
names = sorted(snapshot_names)
31-
32-
if not names:
33-
yield exp.false()
34-
else:
35-
batches = create_batches(names, batch_size=batch_size)
36-
for names in batches:
37-
yield exp.column("name", table=alias).isin(*names)
38-
39-
4025
def snapshot_id_filter(
4126
engine_adapter: EngineAdapter,
4227
snapshot_ids: t.Iterable[SnapshotIdLike],

0 commit comments

Comments
 (0)