Skip to content

Commit ded6795

Browse files
authored
Chore!: Optimize snapshot unpausing (#5147)
1 parent 2897b4e commit ded6795

File tree

6 files changed

+180
-208
lines changed

6 files changed

+180
-208
lines changed

sqlmesh/core/state_sync/db/facade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def finalize(self, environment: Environment) -> None:
256256
def unpause_snapshots(
257257
self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike
258258
) -> None:
259-
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt, self.interval_state)
259+
self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt)
260260

261261
def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
262262
self.environment_state.invalidate_environment(name, protect_prod)

sqlmesh/core/state_sync/db/migrator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def _migrate_environment_rows(
396396
if updated_prod_environment:
397397
try:
398398
self.snapshot_state.unpause_snapshots(
399-
updated_prod_environment.snapshots, now_timestamp(), self.interval_state
399+
updated_prod_environment.snapshots, now_timestamp()
400400
)
401401
except Exception:
402402
logger.warning("Failed to unpause migrated snapshots", exc_info=True)

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 60 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from sqlmesh.core.engine_adapter import EngineAdapter
1212
from sqlmesh.core.state_sync.db.utils import (
13+
snapshot_name_filter,
1314
snapshot_name_version_filter,
1415
snapshot_id_filter,
1516
fetchone,
@@ -32,15 +33,13 @@
3233
SnapshotChangeCategory,
3334
)
3435
from sqlmesh.utils.migration import index_text_type, blob_text_type
35-
from sqlmesh.utils.date import now_timestamp, TimeLike, now, to_timestamp
36+
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
3637
from sqlmesh.utils.pydantic import PydanticModel
3738
from sqlmesh.utils import unique
3839

3940
if t.TYPE_CHECKING:
4041
import pandas as pd
4142

42-
from sqlmesh.core.state_sync.db.interval import IntervalState
43-
4443

4544
logger = logging.getLogger(__name__)
4645

@@ -70,6 +69,7 @@ def __init__(
7069
"unpaused_ts": exp.DataType.build("bigint"),
7170
"ttl_ms": exp.DataType.build("bigint"),
7271
"unrestorable": exp.DataType.build("boolean"),
72+
"forward_only": exp.DataType.build("boolean"),
7373
}
7474

7575
self._auto_restatement_columns_to_types = {
@@ -112,84 +112,52 @@ def unpause_snapshots(
112112
self,
113113
snapshots: t.Collection[SnapshotInfoLike],
114114
unpaused_dt: TimeLike,
115-
interval_state: IntervalState,
116115
) -> None:
117-
"""Unpauses given snapshots while pausing all other snapshots that share the same version.
118-
119-
Args:
120-
snapshots: The snapshots to unpause.
121-
unpaused_dt: The timestamp to unpause the snapshots at.
122-
interval_state: The interval state to use to remove intervals when needed.
123-
"""
124-
current_ts = now()
125-
126-
target_snapshot_ids = {s.snapshot_id for s in snapshots}
127-
same_version_snapshots = self._get_snapshots_with_same_version(
128-
snapshots, lock_for_update=True
116+
unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[SnapshotNameVersion]] = (
117+
defaultdict(list)
129118
)
130-
target_snapshots_by_version = {
131-
(s.name, s.version): s
132-
for s in same_version_snapshots
133-
if s.snapshot_id in target_snapshot_ids
134-
}
135-
136-
unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list)
137-
paused_snapshots: t.List[SnapshotId] = []
138-
unrestorable_snapshots: t.List[SnapshotId] = []
139-
140-
for snapshot in same_version_snapshots:
141-
is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids
142-
if is_target_snapshot and not snapshot.unpaused_ts:
143-
logger.info("Unpausing snapshot %s", snapshot.snapshot_id)
144-
snapshot.set_unpaused_ts(unpaused_dt)
145-
assert snapshot.unpaused_ts is not None
146-
unpaused_snapshots[snapshot.unpaused_ts].append(snapshot.snapshot_id)
147-
elif not is_target_snapshot:
148-
target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)]
149-
if (
150-
target_snapshot.normalized_effective_from_ts
151-
and not target_snapshot.disable_restatement
152-
):
153-
# Making sure that there are no overlapping intervals.
154-
effective_from_ts = target_snapshot.normalized_effective_from_ts
155-
logger.info(
156-
"Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s",
157-
target_snapshot.effective_from,
158-
snapshot.snapshot_id,
159-
target_snapshot.snapshot_id,
160-
)
161-
full_snapshot = snapshot.full_snapshot
162-
interval_state.remove_intervals(
163-
[
164-
(
165-
full_snapshot,
166-
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
167-
)
168-
]
169-
)
170119

171-
if snapshot.unpaused_ts:
172-
logger.info("Pausing snapshot %s", snapshot.snapshot_id)
173-
snapshot.set_unpaused_ts(None)
174-
paused_snapshots.append(snapshot.snapshot_id)
120+
for snapshot in snapshots:
121+
# We need to mark all other snapshots that have forward-only opposite to the target snapshot as unrestorable
122+
unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append(
123+
snapshot.name_version
124+
)
175125

176-
if not snapshot.unrestorable and (
177-
(target_snapshot.is_forward_only and not snapshot.is_forward_only)
178-
or (snapshot.is_forward_only and not target_snapshot.is_forward_only)
179-
):
180-
logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id)
181-
snapshot.unrestorable = True
182-
unrestorable_snapshots.append(snapshot.snapshot_id)
126+
updated_ts = now_timestamp()
127+
unpaused_ts = to_timestamp(unpaused_dt)
183128

184-
if unpaused_snapshots:
185-
for unpaused_ts, snapshot_ids in unpaused_snapshots.items():
186-
self._update_snapshots(snapshot_ids, unpaused_ts=unpaused_ts)
129+
# Pause all snapshots with target names first
130+
for where in snapshot_name_filter(
131+
[s.name for s in snapshots],
132+
batch_size=self.SNAPSHOT_BATCH_SIZE,
133+
):
134+
self.engine_adapter.update_table(
135+
self.snapshots_table,
136+
{"unpaused_ts": None, "updated_ts": updated_ts},
137+
where=where,
138+
)
187139

188-
if paused_snapshots:
189-
self._update_snapshots(paused_snapshots, unpaused_ts=None)
140+
# Now unpause the target snapshots
141+
self._update_snapshots(
142+
[s.snapshot_id for s in snapshots],
143+
unpaused_ts=unpaused_ts,
144+
updated_ts=updated_ts,
145+
)
190146

191-
if unrestorable_snapshots:
192-
self._update_snapshots(unrestorable_snapshots, unrestorable=True)
147+
# Mark unrestorable snapshots
148+
for forward_only, snapshot_name_versions in unrestorable_snapshots_by_forward_only.items():
149+
forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only))
150+
for where in snapshot_name_version_filter(
151+
self.engine_adapter,
152+
snapshot_name_versions,
153+
batch_size=self.SNAPSHOT_BATCH_SIZE,
154+
alias=None,
155+
):
156+
self.engine_adapter.update_table(
157+
self.snapshots_table,
158+
{"unrestorable": True, "updated_ts": updated_ts},
159+
where=forward_only_exp.and_(where),
160+
)
193161

194162
def get_expired_snapshots(
195163
self,
@@ -414,7 +382,8 @@ def _update_snapshots(
414382
**kwargs: t.Any,
415383
) -> None:
416384
properties = kwargs
417-
properties["updated_ts"] = now_timestamp()
385+
if "updated_ts" not in properties:
386+
properties["updated_ts"] = now_timestamp()
418387

419388
for where in snapshot_id_filter(
420389
self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE
@@ -466,13 +435,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
466435
updated_ts,
467436
unpaused_ts,
468437
unrestorable,
438+
forward_only,
469439
next_auto_restatement_ts,
470440
) in fetchall(self.engine_adapter, query):
471441
snapshot = parse_snapshot(
472442
serialized_snapshot=serialized_snapshot,
473443
updated_ts=updated_ts,
474444
unpaused_ts=unpaused_ts,
475445
unrestorable=unrestorable,
446+
forward_only=forward_only,
476447
next_auto_restatement_ts=next_auto_restatement_ts,
477448
)
478449
snapshot_id = snapshot.snapshot_id
@@ -502,6 +473,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
502473
"updated_ts",
503474
"unpaused_ts",
504475
"unrestorable",
476+
"forward_only",
505477
"next_auto_restatement_ts",
506478
)
507479
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
@@ -528,13 +500,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
528500
updated_ts,
529501
unpaused_ts,
530502
unrestorable,
503+
forward_only,
531504
next_auto_restatement_ts,
532505
) in fetchall(self.engine_adapter, query):
533506
snapshot_id = SnapshotId(name=name, identifier=identifier)
534507
snapshot = snapshots[snapshot_id]
535508
snapshot.updated_ts = updated_ts
536509
snapshot.unpaused_ts = unpaused_ts
537510
snapshot.unrestorable = unrestorable
511+
snapshot.forward_only = forward_only
538512
snapshot.next_auto_restatement_ts = next_auto_restatement_ts
539513
cached_snapshots_in_state.add(snapshot_id)
540514

@@ -568,6 +542,7 @@ def _get_snapshots_expressions(
568542
"snapshots.updated_ts",
569543
"snapshots.unpaused_ts",
570544
"snapshots.unrestorable",
545+
"snapshots.forward_only",
571546
"auto_restatements.next_auto_restatement_ts",
572547
)
573548
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
@@ -623,6 +598,7 @@ def _get_snapshots_with_same_version(
623598
"updated_ts",
624599
"unpaused_ts",
625600
"unrestorable",
601+
"forward_only",
626602
)
627603
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
628604
.where(where)
@@ -640,9 +616,10 @@ def _get_snapshots_with_same_version(
640616
updated_ts=updated_ts,
641617
unpaused_ts=unpaused_ts,
642618
unrestorable=unrestorable,
619+
forward_only=forward_only,
643620
snapshot=snapshot,
644621
)
645-
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows
622+
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
646623
]
647624

648625

@@ -651,6 +628,7 @@ def parse_snapshot(
651628
updated_ts: int,
652629
unpaused_ts: t.Optional[int],
653630
unrestorable: bool,
631+
forward_only: bool,
654632
next_auto_restatement_ts: t.Optional[int],
655633
) -> Snapshot:
656634
return Snapshot(
@@ -659,6 +637,7 @@ def parse_snapshot(
659637
"updated_ts": updated_ts,
660638
"unpaused_ts": unpaused_ts,
661639
"unrestorable": unrestorable,
640+
"forward_only": forward_only,
662641
"next_auto_restatement_ts": next_auto_restatement_ts,
663642
}
664643
)
@@ -673,6 +652,7 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
673652
"updated_ts",
674653
"unpaused_ts",
675654
"unrestorable",
655+
"forward_only",
676656
"next_auto_restatement_ts",
677657
}
678658
)
@@ -693,6 +673,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
693673
"unpaused_ts": snapshot.unpaused_ts,
694674
"ttl_ms": snapshot.ttl_ms,
695675
"unrestorable": snapshot.unrestorable,
676+
"forward_only": snapshot.forward_only,
696677
}
697678
for snapshot in snapshots
698679
]
@@ -762,19 +743,10 @@ def full_snapshot(self) -> Snapshot:
762743
"updated_ts": self.updated_ts,
763744
"unpaused_ts": self.unpaused_ts,
764745
"unrestorable": self.unrestorable,
746+
"forward_only": self.forward_only,
765747
}
766748
)
767749

768-
def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
769-
"""Sets the timestamp for when this snapshot was unpaused.
770-
771-
Args:
772-
unpaused_dt: The datetime object of when this snapshot was unpaused.
773-
"""
774-
self.unpaused_ts = (
775-
to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None
776-
)
777-
778750
@classmethod
779751
def from_snapshot_record(
780752
cls,
@@ -785,6 +757,7 @@ def from_snapshot_record(
785757
updated_ts: int,
786758
unpaused_ts: t.Optional[int],
787759
unrestorable: bool,
760+
forward_only: bool,
788761
snapshot: str,
789762
) -> SharedVersionSnapshot:
790763
raw_snapshot = json.loads(snapshot)
@@ -803,5 +776,5 @@ def from_snapshot_record(
803776
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
804777
effective_from=raw_snapshot.get("effective_from"),
805778
raw_snapshot=raw_snapshot,
806-
forward_only=raw_snapshot.get("forward_only", False),
779+
forward_only=forward_only,
807780
)

sqlmesh/core/state_sync/db/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@
2222
T = t.TypeVar("T")
2323

2424

25+
def snapshot_name_filter(
26+
snapshot_names: t.Iterable[str],
27+
batch_size: int,
28+
alias: t.Optional[str] = None,
29+
) -> t.Iterator[exp.Condition]:
30+
names = sorted(snapshot_names)
31+
32+
if not names:
33+
yield exp.false()
34+
else:
35+
batches = create_batches(names, batch_size=batch_size)
36+
for names in batches:
37+
yield exp.column("name", table=alias).isin(*names)
38+
39+
2540
def snapshot_id_filter(
2641
engine_adapter: EngineAdapter,
2742
snapshot_ids: t.Iterable[SnapshotIdLike],

0 commit comments

Comments
 (0)