Skip to content

Commit 050d94e

Browse files
committed
Chore!: Optimize snapshot unpausing
1 parent 5aacaa8 commit 050d94e

File tree

6 files changed

+177
-209
lines changed

6 files changed

+177
-209
lines changed

sqlmesh/core/state_sync/db/facade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def finalize(self, environment: Environment) -> None:
256256
def unpause_snapshots(
257257
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
258258
) -> None:
259-
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state)
259+
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt)
260260

261261
def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
262262
self.environment_state.invalidate_environment(name, protect_prod)

sqlmesh/core/state_sync/db/migrator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def _migrate_environment_rows(
396396
if updated_prod_environment:
397397
try:
398398
self.snapshot_state.unpause_snapshots(
399-
updated_prod_environment.snapshots, now_timestamp(), self.interval_state
399+
updated_prod_environment.snapshots, now_timestamp()
400400
)
401401
except Exception:
402402
logger.warning("Failed to unpause migrated snapshots", exc_info=True)

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 57 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from sqlmesh.core.engine_adapter import EngineAdapter
1212
from sqlmesh.core.state_sync.db.utils import (
13+
snapshot_name_filter,
1314
snapshot_name_version_filter,
1415
snapshot_id_filter,
1516
fetchone,
@@ -32,15 +33,13 @@
3233
SnapshotChangeCategory,
3334
)
3435
from sqlmesh.utils.migration import index_text_type, blob_text_type
35-
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
36+
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
3637
from sqlmesh.utils.pydantic import PydanticModel
3738
from sqlmesh.utils import unique
3839

3940
if t.TYPE_CHECKING:
4041
import pandas as pd
4142

42-
from sqlmesh.core.state_sync.db.interval import IntervalState
43-
4443

4544
logger = logging.getLogger(__name__)
4645

@@ -70,6 +69,7 @@ def __init__(
7069
"unpaused_ts": exp.DataType.build("bigint"),
7170
"ttl_ms": exp.DataType.build("bigint"),
7271
"unrestorable": exp.DataType.build("boolean"),
72+
"forward_only": exp.DataType.build("boolean"),
7373
}
7474

7575
self._auto_restatement_columns_to_types = {
@@ -112,84 +112,48 @@ def unpause_snapshots(
112112
self,
113113
snapshots: t.Collection[SnapshotInfoLike],
114114
unpaused_dt: TimeLike,
115-
interval_state: IntervalState,
116115
) -> None:
117-
"""Unpauses given snapshots while pausing all other snapshots that share the same version.
118-
119-
Args:
120-
snapshots: The snapshots to unpause.
121-
unpaused_dt: The timestamp to unpause the snapshots at.
122-
interval_state: The interval state to use to remove intervals when needed.
123-
"""
124-
current_ts = now()
125-
126-
target_snapshot_ids = {s.snapshot_id for s in snapshots}
127-
same_version_snapshots = self._get_snapshots_with_same_version(
128-
snapshots, lock_for_update=True
129-
)
130-
target_snapshots_by_version = {
131-
(s.name, s.version): s
132-
for s in same_version_snapshots
133-
if s.snapshot_id in target_snapshot_ids
134-
}
116+
unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[str]] = defaultdict(list)
135117

136-
unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list)
137-
paused_snapshots: t.List[SnapshotId] = []
138-
unrestorable_snapshots: t.List[SnapshotId] = []
139-
140-
for snapshot in same_version_snapshots:
141-
is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids
142-
if is_target_snapshot and not snapshot.unpaused_ts:
143-
logger.info("Unpausing snapshot %s", snapshot.snapshot_id)
144-
snapshot.set_unpaused_ts(unpaused_dt)
145-
assert snapshot.unpaused_ts is not None
146-
unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id)
147-
elif not is_target_snapshot:
148-
target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)]
149-
if (
150-
target_snapshot.normalized_effective_from_ts
151-
and not target_snapshot.disable_restatement
152-
):
153-
# Making sure that there are no overlapping intervals.
154-
effective_from_ts = target_snapshot.normalized_effective_from_ts
155-
logger.info(
156-
"Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s",
157-
target_snapshot.effective_from,
158-
snapshot.snapshot_id,
159-
target_snapshot.snapshot_id,
160-
)
161-
full_snapshot = snapshot.full_snapshot
162-
interval_state.remove_intervals(
163-
[
164-
(
165-
full_snapshot,
166-
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
167-
)
168-
]
169-
)
170-
171-
if snapshot.unpaused_ts:
172-
logger.info("Pausing snapshot %s", snapshot.snapshot_id)
173-
snapshot.set_unpaused_ts(None)
174-
paused_snapshots.append(snapshot.snapshot_id)
118+
for snapshot in snapshots:
119+
# We need to mark all other snapshots that have opposite forward only status as unrestorable
120+
unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append(
121+
snapshot.name
122+
)
175123

176-
if not snapshot.unrestorable and (
177-
(target_snapshot.is_forward_only and not snapshot.is_forward_only)
178-
or (snapshot.is_forward_only and not target_snapshot.is_forward_only)
179-
):
180-
logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id)
181-
snapshot.unrestorable = True
182-
unrestorable_snapshots.append(snapshot.snapshot_id)
124+
updated_ts = now_timestamp()
125+
unpaused_ts = to_timestamp(unpaused_dt)
183126

184-
if unpaused_snapshots:
185-
for unpaused_ts, snapshot_ids in unpaused_snapshots.items():
186-
self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts)
127+
# Pause all snapshots with target names first
128+
for where in snapshot_name_filter(
129+
[s.name for s in snapshots],
130+
batch_size=self.SNAPSHOT_BATCH_SIZE,
131+
):
132+
self.engine_adapter.update_table(
133+
self.snapshots_table,
134+
{"unpaused_ts": None, "updated_ts": updated_ts},
135+
where=where,
136+
)
187137

188-
if paused_snapshots:
189-
self._update_snapshots(paused_snapshots, unpaused_ts=None)
138+
# Now unpause the target snapshots
139+
self._update_snapshots(
140+
[s.snapshot_id for s in snapshots],
141+
unpaused_ts=unpaused_ts,
142+
updated_ts=updated_ts,
143+
)
190144

191-
if unrestorable_snapshots:
192-
self._update_snapshots(unrestorable_snapshots, unrestorable=True)
145+
# Mark unrestorable snapshots
146+
for forward_only, snapshot_names in unrestorable_snapshots_by_forward_only.items():
147+
forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only))
148+
for where in snapshot_name_filter(
149+
snapshot_names,
150+
batch_size=self.SNAPSHOT_BATCH_SIZE,
151+
):
152+
self.engine_adapter.update_table(
153+
self.snapshots_table,
154+
{"unrestorable": True, "updated_ts": updated_ts},
155+
where=forward_only_exp.and_(where),
156+
)
193157

194158
def get_expired_snapshots(
195159
self,
@@ -414,7 +378,8 @@ def _update_snapshots(
414378
**kwargs: t.Any,
415379
) -> None:
416380
properties = kwargs
417-
properties["updated_ts"] = now_timestamp()
381+
if "updated_ts" not in properties:
382+
properties["updated_ts"] = now_timestamp()
418383

419384
for where in snapshot_id_filter(
420385
self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE
@@ -466,13 +431,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
466431
updated_ts,
467432
unpaused_ts,
468433
unrestorable,
434+
forward_only,
469435
next_auto_restatement_ts,
470436
) in fetchall(self.engine_adapter, query):
471437
snapshot = parse_snapshot(
472438
serialized_snapshot=serialized_snapshot,
473439
updated_ts=updated_ts,
474440
unpaused_ts=unpaused_ts,
475441
unrestorable=unrestorable,
442+
forward_only=forward_only,
476443
next_auto_restatement_ts=next_auto_restatement_ts,
477444
)
478445
snapshot_id = snapshot.snapshot_id
@@ -502,6 +469,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
502469
"updated_ts",
503470
"unpaused_ts",
504471
"unrestorable",
472+
"forward_only",
505473
"next_auto_restatement_ts",
506474
)
507475
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
@@ -528,13 +496,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
528496
updated_ts,
529497
unpaused_ts,
530498
unrestorable,
499+
forward_only,
531500
next_auto_restatement_ts,
532501
) in fetchall(self.engine_adapter, query):
533502
snapshot_id = SnapshotId(name=name, identifier=identifier)
534503
snapshot = snapshots[snapshot_id]
535504
snapshot.updated_ts = updated_ts
536505
snapshot.unpaused_ts = unpaused_ts
537506
snapshot.unrestorable = unrestorable
507+
snapshot.forward_only = forward_only
538508
snapshot.next_auto_restatement_ts = next_auto_restatement_ts
539509
cached_snapshots_in_state.add(snapshot_id)
540510

@@ -568,6 +538,7 @@ def _get_snapshots_expressions(
568538
"snapshots.updated_ts",
569539
"snapshots.unpaused_ts",
570540
"snapshots.unrestorable",
541+
"snapshots.forward_only",
571542
"auto_restatements.next_auto_restatement_ts",
572543
)
573544
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
@@ -623,6 +594,7 @@ def _get_snapshots_with_same_version(
623594
"updated_ts",
624595
"unpaused_ts",
625596
"unrestorable",
597+
"forward_only",
626598
)
627599
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
628600
.where(where)
@@ -640,9 +612,10 @@ def _get_snapshots_with_same_version(
640612
updated_ts=updated_ts,
641613
unpaused_ts=unpaused_ts,
642614
unrestorable=unrestorable,
615+
forward_only=forward_only,
643616
snapshot=snapshot,
644617
)
645-
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows
618+
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
646619
]
647620

648621

@@ -651,6 +624,7 @@ def parse_snapshot(
651624
updated_ts: int,
652625
unpaused_ts: t.Optional[int],
653626
unrestorable: bool,
627+
forward_only: bool,
654628
next_auto_restatement_ts: t.Optional[int],
655629
) -> Snapshot:
656630
return Snapshot(
@@ -659,6 +633,7 @@ def parse_snapshot(
659633
"updated_ts": updated_ts,
660634
"unpaused_ts": unpaused_ts,
661635
"unrestorable": unrestorable,
636+
"forward_only": forward_only,
662637
"next_auto_restatement_ts": next_auto_restatement_ts,
663638
}
664639
)
@@ -673,6 +648,7 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
673648
"updated_ts",
674649
"unpaused_ts",
675650
"unrestorable",
651+
"forward_only",
676652
"next_auto_restatement_ts",
677653
}
678654
)
@@ -693,6 +669,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
693669
"unpaused_ts": snapshot.unpaused_ts,
694670
"ttl_ms": snapshot.ttl_ms,
695671
"unrestorable": snapshot.unrestorable,
672+
"forward_only": snapshot.forward_only,
696673
}
697674
for snapshot in snapshots
698675
]
@@ -762,19 +739,10 @@ def full_snapshot(self) -> Snapshot:
762739
"updated_ts": self.updated_ts,
763740
"unpaused_ts": self.unpaused_ts,
764741
"unrestorable": self.unrestorable,
742+
"forward_only": self.forward_only,
765743
}
766744
)
767745

768-
def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
769-
"""Sets the timestamp for when this snapshot was unpaused.
770-
771-
Args:
772-
unpaused_dt: The datetime object of when this snapshot was unpaused.
773-
"""
774-
self.unpaused_ts = (
775-
to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None
776-
)
777-
778746
@classmethod
779747
def from_snapshot_record(
780748
cls,
@@ -785,6 +753,7 @@ def from_snapshot_record(
785753
updated_ts: int,
786754
unpaused_ts: t.Optional[int],
787755
unrestorable: bool,
756+
forward_only: bool,
788757
snapshot: str,
789758
) -> SharedVersionSnapshot:
790759
raw_snapshot = json.loads(snapshot)
@@ -803,5 +772,5 @@ def from_snapshot_record(
803772
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
804773
effective_from=raw_snapshot.get("effective_from"),
805774
raw_snapshot=raw_snapshot,
806-
forward_only=raw_snapshot.get("forward_only", False),
775+
forward_only=forward_only,
807776
)

sqlmesh/core/state_sync/db/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@
2222
T = t.TypeVar("T")
2323

2424

25+
def snapshot_name_filter(
26+
snapshot_names: t.Iterable[str],
27+
batch_size: int,
28+
alias: t.Optional[str] = None,
29+
) -> t.Iterator[exp.Condition]:
30+
names = sorted(snapshot_names)
31+
32+
if not names:
33+
yield exp.false()
34+
else:
35+
batches = create_batches(names, batch_size=batch_size)
36+
for names in batches:
37+
yield exp.column("name", table=alias).isin(*names)
38+
39+
2540
def snapshot_id_filter(
2641
engine_adapter: EngineAdapter,
2742
snapshot_ids: t.Iterable[SnapshotIdLike],

0 commit comments

Comments
 (0)