Skip to content

Commit 295d37a

Browse files
authored
Fix: Reduce parsing overhead and improve performance when unpausing snapshots deployed to prod (#3627)
1 parent 3686624 commit 295d37a

File tree

4 files changed

+217
-32
lines changed

4 files changed

+217
-32
lines changed

sqlmesh/core/snapshot/definition.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -940,16 +940,6 @@ def categorize_as(self, category: SnapshotChangeCategory) -> None:
940940

941941
self.change_category = category
942942

943-
def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
944-
"""Sets the timestamp for when this snapshot was unpaused.
945-
946-
Args:
947-
unpaused_dt: The datetime object of when this snapshot was unpaused.
948-
"""
949-
self.unpaused_ts = (
950-
to_timestamp(self.node.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None
951-
)
952-
953943
def table_name(self, is_deployable: bool = True) -> str:
954944
"""Full table name pointing to the materialized location of the snapshot.
955945

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 133 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from datetime import datetime
2828

2929
import pandas as pd
30+
from pydantic import Field
3031
from sqlglot import __version__ as SQLGLOT_VERSION
3132
from sqlglot import exp
3233
from sqlglot.helper import seq_get
@@ -37,10 +38,12 @@
3738
from sqlmesh.core.engine_adapter import EngineAdapter
3839
from sqlmesh.core.environment import Environment
3940
from sqlmesh.core.model import ModelKindName, SeedModel
41+
from sqlmesh.core.node import IntervalUnit
4042
from sqlmesh.core.snapshot import (
4143
Intervals,
4244
Node,
4345
Snapshot,
46+
SnapshotChangeCategory,
4447
SnapshotFingerprint,
4548
SnapshotId,
4649
SnapshotIdLike,
@@ -73,6 +76,7 @@
7376
from sqlmesh.utils.date import TimeLike, now, now_timestamp, time_like_to_str, to_timestamp
7477
from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError
7578
from sqlmesh.utils.migration import blob_text_type, index_text_type
79+
from sqlmesh.utils.pydantic import PydanticModel
7680

7781
logger = logging.getLogger(__name__)
7882

@@ -434,16 +438,20 @@ def unpause_snapshots(
434438
current_ts = now()
435439

436440
target_snapshot_ids = {s.snapshot_id for s in snapshots}
437-
snapshots = self._get_snapshots_with_same_version(snapshots, lock_for_update=True)
441+
same_version_snapshots = self._get_snapshots_with_same_version(
442+
snapshots, lock_for_update=True
443+
)
438444
target_snapshots_by_version = {
439-
(s.name, s.version): s for s in snapshots if s.snapshot_id in target_snapshot_ids
445+
(s.name, s.version): s
446+
for s in same_version_snapshots
447+
if s.snapshot_id in target_snapshot_ids
440448
}
441449

442450
unpaused_snapshots: t.Dict[int, t.List[SnapshotId]] = defaultdict(list)
443451
paused_snapshots: t.List[SnapshotId] = []
444452
unrestorable_snapshots: t.List[SnapshotId] = []
445453

446-
for snapshot in snapshots:
454+
for snapshot in same_version_snapshots:
447455
is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids
448456
if is_target_snapshot and not snapshot.unpaused_ts:
449457
logger.info("Unpausing snapshot %s", snapshot.snapshot_id)
@@ -464,8 +472,14 @@ def unpause_snapshots(
464472
snapshot.snapshot_id,
465473
target_snapshot.snapshot_id,
466474
)
475+
full_snapshot = snapshot.full_snapshot
467476
self.remove_intervals(
468-
[(snapshot, snapshot.get_removal_interval(effective_from_ts, current_ts))]
477+
[
478+
(
479+
full_snapshot,
480+
full_snapshot.get_removal_interval(effective_from_ts, current_ts),
481+
)
482+
]
469483
)
470484

471485
if snapshot.unpaused_ts:
@@ -555,7 +569,7 @@ def delete_expired_snapshots(
555569
for snapshot in environment.snapshots
556570
}
557571

558-
def _is_snapshot_used(snapshot: Snapshot) -> bool:
572+
def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
559573
return (
560574
snapshot.snapshot_id in promoted_snapshot_ids
561575
or snapshot.snapshot_id not in expired_candidates
@@ -571,28 +585,26 @@ def _is_snapshot_used(snapshot: Snapshot) -> bool:
571585
snapshots_by_temp_version = defaultdict(set)
572586
for s in snapshots:
573587
snapshots_by_version[(s.name, s.version)].add(s.snapshot_id)
574-
snapshots_by_temp_version[(s.name, s.temp_version_get_or_generate())].add(
575-
s.snapshot_id
576-
)
588+
snapshots_by_temp_version[(s.name, s.temp_version)].add(s.snapshot_id)
577589

578590
expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)]
579591

580592
if expired_snapshots:
581-
self.delete_snapshots(expired_snapshots)
593+
self.delete_snapshots([s.snapshot_id for s in expired_snapshots])
582594

583595
for snapshot in expired_snapshots:
584596
shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)]
585597
shared_version_snapshots.discard(snapshot.snapshot_id)
586598

587599
shared_temp_version_snapshots = snapshots_by_temp_version[
588-
(snapshot.name, snapshot.temp_version_get_or_generate())
600+
(snapshot.name, snapshot.temp_version)
589601
]
590602
shared_temp_version_snapshots.discard(snapshot.snapshot_id)
591603

592604
if not shared_temp_version_snapshots:
593605
cleanup_targets.append(
594606
SnapshotTableCleanupTask(
595-
snapshot=snapshot.table_info,
607+
snapshot=snapshot.full_snapshot.table_info,
596608
dev_table_only=bool(shared_version_snapshots),
597609
)
598610
)
@@ -903,7 +915,7 @@ def _get_snapshots_with_same_version(
903915
self,
904916
snapshots: t.Collection[SnapshotNameVersionLike],
905917
lock_for_update: bool = False,
906-
) -> t.List[Snapshot]:
918+
) -> t.List[SharedVersionSnapshot]:
907919
"""Fetches all snapshots that share the same version as the snapshots.
908920
909921
The output includes the snapshots with the specified identifiers.
@@ -922,7 +934,15 @@ def _get_snapshots_with_same_version(
922934

923935
for where in self._snapshot_name_version_filter(snapshots):
924936
query = (
925-
exp.select("snapshot", "updated_ts", "unpaused_ts", "unrestorable")
937+
exp.select(
938+
"snapshot",
939+
"name",
940+
"identifier",
941+
"version",
942+
"updated_ts",
943+
"unpaused_ts",
944+
"unrestorable",
945+
)
926946
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
927947
.where(where)
928948
)
@@ -932,15 +952,16 @@ def _get_snapshots_with_same_version(
932952
snapshot_rows.extend(self._fetchall(query))
933953

934954
return [
935-
Snapshot(
936-
**{
937-
**json.loads(snapshot),
938-
"updated_ts": updated_ts,
939-
"unpaused_ts": unpaused_ts,
940-
"unrestorable": unrestorable,
941-
}
955+
SharedVersionSnapshot.from_snapshot_record(
956+
name=name,
957+
identifier=identifier,
958+
version=version,
959+
updated_ts=updated_ts,
960+
unpaused_ts=unpaused_ts,
961+
unrestorable=unrestorable,
962+
snapshot=snapshot,
942963
)
943-
for snapshot, updated_ts, unpaused_ts, unrestorable in snapshot_rows
964+
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable in snapshot_rows
944965
]
945966

946967
def _get_versions(self, lock_for_update: bool = False) -> Versions:
@@ -1942,3 +1963,94 @@ def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot:
19421963
if snapshot is None:
19431964
raise KeyError(snapshot_id)
19441965
return snapshot
1966+
1967+
1968+
class SharedVersionSnapshot(PydanticModel):
    """A stripped down version of a snapshot that is used for fetching snapshots that share the same version
    with a significantly reduced parsing overhead.

    Only the attributes declared below are extracted from a stored snapshot record. The
    decoded JSON payload is kept in ``raw_snapshot`` so that a complete ``Snapshot`` can
    still be constructed on demand through ``full_snapshot``.
    """

    name: str
    version: str
    # Read through the `temp_version` property, which falls back to the fingerprint.
    temp_version_: t.Optional[str] = Field(alias="temp_version")
    identifier: str
    fingerprint: SnapshotFingerprint
    interval_unit: IntervalUnit
    change_category: SnapshotChangeCategory
    updated_ts: int
    unpaused_ts: t.Optional[int]
    unrestorable: bool
    disable_restatement: bool
    effective_from: t.Optional[TimeLike]
    # The decoded JSON payload of the record; the input for `full_snapshot`.
    raw_snapshot: t.Dict[str, t.Any]

    @property
    def snapshot_id(self) -> SnapshotId:
        """The snapshot ID built from this record's name and identifier."""
        return SnapshotId(name=self.name, identifier=self.identifier)

    @property
    def is_forward_only(self) -> bool:
        """Whether this snapshot's change category is FORWARD_ONLY."""
        return self.change_category == SnapshotChangeCategory.FORWARD_ONLY

    @property
    def normalized_effective_from_ts(self) -> t.Optional[int]:
        """The effective-from value floored with the interval unit, as a timestamp.

        Returns None when no effective-from value is set.
        """
        return (
            to_timestamp(self.interval_unit.cron_floor(self.effective_from))
            if self.effective_from
            else None
        )

    @property
    def temp_version(self) -> str:
        """The stored temp version, or the version derived from the fingerprint if none is stored."""
        return self.temp_version_ or self.fingerprint.to_version()

    @property
    def full_snapshot(self) -> Snapshot:
        """Builds the complete snapshot from the raw payload.

        The current ``updated_ts``, ``unpaused_ts`` and ``unrestorable`` values of this
        object override whatever the raw payload contains. A new ``Snapshot`` is
        constructed on every access, so callers that need it more than once should
        bind the result to a local variable.
        """
        return Snapshot(
            **{
                **self.raw_snapshot,
                "updated_ts": self.updated_ts,
                "unpaused_ts": self.unpaused_ts,
                "unrestorable": self.unrestorable,
            }
        )

    def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None:
        """Sets the timestamp for when this snapshot was unpaused.

        The value is floored with the snapshot's interval unit before being stored.

        Args:
            unpaused_dt: The datetime object of when this snapshot was unpaused.
                A falsy value clears the unpaused timestamp.
        """
        self.unpaused_ts = (
            to_timestamp(self.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None
        )

    @classmethod
    def from_snapshot_record(
        cls,
        *,
        name: str,
        identifier: str,
        version: str,
        updated_ts: int,
        unpaused_ts: t.Optional[int],
        unrestorable: bool,
        snapshot: str,
    ) -> SharedVersionSnapshot:
        """Creates an instance from the columns of a stored snapshot record.

        Args:
            name: The snapshot name column.
            identifier: The snapshot identifier column.
            version: The snapshot version column.
            updated_ts: The updated timestamp column.
            unpaused_ts: The unpaused timestamp column, if any.
            unrestorable: The unrestorable flag column.
            snapshot: The JSON-serialized snapshot payload.

        Returns:
            The stripped down snapshot. The decoded payload is retained in ``raw_snapshot``.

        Raises:
            KeyError: If the payload lacks "node", "fingerprint" or "change_category", or if
                the node carries neither an interval unit nor a cron expression.
        """
        raw_snapshot = json.loads(snapshot)
        raw_node = raw_snapshot["node"]

        # Resolve the cron fallback lazily. Supplying it as the default argument of
        # dict.get() evaluates it for every record, which both pays for cron parsing
        # needlessly and raises KeyError for nodes that have an interval unit but no cron.
        interval_unit = raw_node.get("interval_unit")
        if interval_unit is None:
            interval_unit = IntervalUnit.from_cron(raw_node["cron"])

        # Tolerate a node whose "kind" entry is missing or null.
        raw_kind = raw_node.get("kind") or {}

        return cls(
            name=name,
            version=version,
            temp_version=raw_snapshot.get("temp_version"),
            identifier=identifier,
            fingerprint=raw_snapshot["fingerprint"],
            interval_unit=interval_unit,
            change_category=raw_snapshot["change_category"],
            updated_ts=updated_ts,
            unpaused_ts=unpaused_ts,
            unrestorable=unrestorable,
            disable_restatement=raw_kind.get("disable_restatement", False),
            effective_from=raw_snapshot.get("effective_from"),
            raw_snapshot=raw_snapshot,
        )

tests/core/test_snapshot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,7 @@ def test_has_paused_forward_only(snapshot: Snapshot):
14001400
snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
14011401
assert has_paused_forward_only([snapshot], [snapshot])
14021402

1403-
snapshot.set_unpaused_ts("2023-01-01")
1403+
snapshot.unpaused_ts = to_timestamp("2023-01-01")
14041404
assert not has_paused_forward_only([snapshot], [snapshot])
14051405

14061406

tests/core/test_state_sync.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,50 @@ def test_unpause_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.
14291429
assert not actual_snapshots[new_snapshot.snapshot_id].unrestorable
14301430

14311431

1432+
def test_unpause_snapshots_hourly(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable):
    """Checks that the unpaused timestamp is floored to the hour for hourly snapshots.

    Also checks that unpausing a forward-only snapshot sharing the same version leaves
    the previously unpaused snapshot without an unpaused timestamp.
    """
    # Deliberately not aligned with the hourly cron.
    unpaused_dt = "2022-01-01 01:22:33"

    hourly_model = SqlModel(
        name="test_snapshot",
        query=parse_one("select 1, ds"),
        cron="@hourly",
    )
    snapshot = make_snapshot(hourly_model)
    snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
    snapshot.version = "a"

    assert not snapshot.unpaused_ts
    state_sync.push_snapshots([snapshot])
    state_sync.unpause_snapshots([snapshot], unpaused_dt)

    fetched = state_sync.get_snapshots([snapshot])
    actual_snapshot = fetched[snapshot.snapshot_id]
    assert actual_snapshot.unpaused_ts
    assert actual_snapshot.unpaused_ts == to_timestamp("2022-01-01 01:00:00")

    # Daily cron, but with an explicit hourly interval unit.
    daily_model = SqlModel(
        name="test_snapshot",
        query=parse_one("select 2, ds"),
        cron="@daily",
        interval_unit="hour",
    )
    new_snapshot = make_snapshot(daily_model)
    new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
    new_snapshot.version = "a"

    assert not new_snapshot.unpaused_ts
    state_sync.push_snapshots([new_snapshot])
    state_sync.unpause_snapshots([new_snapshot], unpaused_dt)

    actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot])
    previous = actual_snapshots[snapshot.snapshot_id]
    current = actual_snapshots[new_snapshot.snapshot_id]
    assert not previous.unpaused_ts
    assert current.unpaused_ts == to_timestamp("2022-01-01 01:00:00")
1474+
1475+
14321476
def test_unrestorable_snapshot(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable):
14331477
snapshot = make_snapshot(
14341478
SqlModel(
@@ -1529,6 +1573,45 @@ def test_unpause_snapshots_remove_intervals(
15291573
]
15301574

15311575

1576+
def test_unpause_snapshots_remove_intervals_disabled_restatement(
    state_sync: EngineAdapterStateSync, make_snapshot: t.Callable
):
    """Checks that unpausing keeps the old snapshot's intervals when restatement is disabled."""
    kind = {
        "name": "INCREMENTAL_BY_TIME_RANGE",
        "time_column": "ds",
        "disable_restatement": True,
    }

    old_model = SqlModel(
        name="test_snapshot",
        query=parse_one("select 1, ds"),
        cron="@daily",
        kind=kind,
    )
    snapshot = make_snapshot(old_model, version="a")
    snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
    snapshot.version = "a"
    state_sync.push_snapshots([snapshot])
    state_sync.add_interval(snapshot, "2023-01-01", "2023-01-05")

    new_model = SqlModel(
        name="test_snapshot",
        query=parse_one("select 2, ds"),
        cron="@daily",
        kind=kind,
    )
    new_snapshot = make_snapshot(new_model, version="a")
    new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY)
    new_snapshot.version = "a"
    new_snapshot.effective_from = "2023-01-03"
    state_sync.push_snapshots([new_snapshot])
    state_sync.add_interval(snapshot, "2023-01-06", "2023-01-06")
    state_sync.unpause_snapshots([new_snapshot], "2023-01-06")

    actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot])

    new_intervals = actual_snapshots[new_snapshot.snapshot_id].intervals
    assert new_intervals == [
        (to_timestamp("2023-01-01"), to_timestamp("2023-01-03")),
    ]

    # The intervals shouldn't have been removed because restatement is disabled
    old_intervals = actual_snapshots[snapshot.snapshot_id].intervals
    assert old_intervals == [
        (to_timestamp("2023-01-01"), to_timestamp("2023-01-07")),
    ]
1613+
1614+
15321615
def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None:
15331616
from sqlmesh import __version__ as SQLMESH_VERSION
15341617

0 commit comments

Comments
 (0)