Skip to content

Commit b0c623e

Browse files
authored
Fix!: Limit the number of fetched full snapshots when deleting expired snapshots (#5281)
1 parent 7cdbcde commit b0c623e

File tree

6 files changed

+200
-126
lines changed

6 files changed

+200
-126
lines changed

sqlmesh/core/state_sync/base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre
292292

293293
@abc.abstractmethod
294294
def get_expired_snapshots(
295-
self, current_ts: int, ignore_ttl: bool = False
295+
self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False
296296
) -> t.List[SnapshotTableCleanupTask]:
297297
"""Aggregates the id's of the expired snapshots and creates a list of table cleanup tasks.
298298
@@ -341,7 +341,7 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
341341
@abc.abstractmethod
342342
def delete_expired_snapshots(
343343
self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None
344-
) -> t.List[SnapshotTableCleanupTask]:
344+
) -> None:
345345
"""Removes expired snapshots.
346346
347347
Expired snapshots are snapshots that have exceeded their time-to-live
@@ -350,9 +350,6 @@ def delete_expired_snapshots(
350350
Args:
351351
ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting
352352
all snapshots that are not referenced in any environment
353-
354-
Returns:
355-
The list of snapshot table cleanup tasks.
356353
"""
357354

358355
@abc.abstractmethod

sqlmesh/core/state_sync/cache.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
SnapshotId,
99
SnapshotIdLike,
1010
SnapshotInfoLike,
11-
SnapshotTableCleanupTask,
1211
)
1312
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
1413
from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync
@@ -109,12 +108,10 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
109108

110109
def delete_expired_snapshots(
111110
self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None
112-
) -> t.List[SnapshotTableCleanupTask]:
111+
) -> None:
113112
current_ts = current_ts or now_timestamp()
114113
self.snapshot_cache.clear()
115-
return self.state_sync.delete_expired_snapshots(
116-
current_ts=current_ts, ignore_ttl=ignore_ttl
117-
)
114+
self.state_sync.delete_expired_snapshots(current_ts=current_ts, ignore_ttl=ignore_ttl)
118115

119116
def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
120117
for snapshot_intervals in snapshots_intervals:

sqlmesh/core/state_sync/db/facade.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,9 @@ def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
262262
self.environment_state.invalidate_environment(name, protect_prod)
263263

264264
def get_expired_snapshots(
265-
self, current_ts: int, ignore_ttl: bool = False
265+
self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False
266266
) -> t.List[SnapshotTableCleanupTask]:
267+
current_ts = current_ts or now_timestamp()
267268
return self.snapshot_state.get_expired_snapshots(
268269
self.environment_state.get_environments(), current_ts=current_ts, ignore_ttl=ignore_ttl
269270
)
@@ -274,16 +275,13 @@ def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary
274275
@transactional()
275276
def delete_expired_snapshots(
276277
self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None
277-
) -> t.List[SnapshotTableCleanupTask]:
278+
) -> None:
278279
current_ts = current_ts or now_timestamp()
279-
expired_snapshot_ids, cleanup_targets = self.snapshot_state._get_expired_snapshots(
280+
for expired_snapshot_ids, cleanup_targets in self.snapshot_state._get_expired_snapshots(
280281
self.environment_state.get_environments(), ignore_ttl=ignore_ttl, current_ts=current_ts
281-
)
282-
283-
self.snapshot_state.delete_snapshots(expired_snapshot_ids)
284-
self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids)
285-
286-
return cleanup_targets
282+
):
283+
self.snapshot_state.delete_snapshots(expired_snapshot_ids)
284+
self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids)
287285

288286
@transactional()
289287
def delete_expired_environments(

sqlmesh/core/state_sync/db/snapshot.py

Lines changed: 45 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
fetchall,
1818
create_batches,
1919
)
20-
from sqlmesh.core.node import IntervalUnit
2120
from sqlmesh.core.environment import Environment
2221
from sqlmesh.core.model import SeedModel, ModelKindName
2322
from sqlmesh.core.snapshot.cache import SnapshotCache
@@ -30,7 +29,6 @@
3029
Snapshot,
3130
SnapshotId,
3231
SnapshotFingerprint,
33-
SnapshotChangeCategory,
3432
)
3533
from sqlmesh.utils.migration import index_text_type, blob_text_type
3634
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
@@ -46,6 +44,9 @@
4644

4745
class SnapshotState:
4846
SNAPSHOT_BATCH_SIZE = 1000
47+
# Use a smaller batch size for expired snapshots to account for the fetching
48+
# of all snapshots that share the same version.
49+
EXPIRED_SNAPSHOT_BATCH_SIZE = 200
4950

5051
def __init__(
5152
self,
@@ -63,13 +64,15 @@ def __init__(
6364
"name": exp.DataType.build(index_type),
6465
"identifier": exp.DataType.build(index_type),
6566
"version": exp.DataType.build(index_type),
67+
"dev_version": exp.DataType.build(index_type),
6668
"snapshot": exp.DataType.build(blob_type),
6769
"kind_name": exp.DataType.build("text"),
6870
"updated_ts": exp.DataType.build("bigint"),
6971
"unpaused_ts": exp.DataType.build("bigint"),
7072
"ttl_ms": exp.DataType.build("bigint"),
7173
"unrestorable": exp.DataType.build("boolean"),
7274
"forward_only": exp.DataType.build("boolean"),
75+
"fingerprint": exp.DataType.build(blob_type),
7376
}
7477

7578
self._auto_restatement_columns_to_types = {
@@ -175,19 +178,21 @@ def get_expired_snapshots(
175178
The set of expired snapshot ids.
176179
The list of table cleanup tasks.
177180
"""
178-
_, cleanup_targets = self._get_expired_snapshots(
181+
all_cleanup_targets = []
182+
for _, cleanup_targets in self._get_expired_snapshots(
179183
environments=environments,
180184
current_ts=current_ts,
181185
ignore_ttl=ignore_ttl,
182-
)
183-
return cleanup_targets
186+
):
187+
all_cleanup_targets.extend(cleanup_targets)
188+
return all_cleanup_targets
184189

185190
def _get_expired_snapshots(
186191
self,
187192
environments: t.Iterable[Environment],
188193
current_ts: int,
189194
ignore_ttl: bool = False,
190-
) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]:
195+
) -> t.Iterator[t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]]:
191196
expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table)
192197

193198
if not ignore_ttl:
@@ -202,7 +207,7 @@ def _get_expired_snapshots(
202207
for name, identifier, version in fetchall(self.engine_adapter, expired_query)
203208
}
204209
if not expired_candidates:
205-
return set(), []
210+
return
206211

207212
promoted_snapshot_ids = {
208213
snapshot.snapshot_id
@@ -218,10 +223,8 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
218223

219224
unique_expired_versions = unique(expired_candidates.values())
220225
version_batches = create_batches(
221-
unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE
226+
unique_expired_versions, batch_size=self.EXPIRED_SNAPSHOT_BATCH_SIZE
222227
)
223-
cleanup_targets = []
224-
expired_snapshot_ids = set()
225228
for versions_batch in version_batches:
226229
snapshots = self._get_snapshots_with_same_version(versions_batch)
227230

@@ -232,8 +235,9 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
232235
snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id)
233236

234237
expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)]
235-
expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots])
238+
all_expired_snapshot_ids = {s.snapshot_id for s in expired_snapshots}
236239

240+
cleanup_targets: t.List[t.Tuple[SnapshotId, bool]] = []
237241
for snapshot in expired_snapshots:
238242
shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)]
239243
shared_version_snapshots.discard(snapshot.snapshot_id)
@@ -244,14 +248,30 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
244248
shared_dev_version_snapshots.discard(snapshot.snapshot_id)
245249

246250
if not shared_dev_version_snapshots:
247-
cleanup_targets.append(
248-
SnapshotTableCleanupTask(
249-
snapshot=snapshot.full_snapshot.table_info,
250-
dev_table_only=bool(shared_version_snapshots),
251-
)
251+
dev_table_only = bool(shared_version_snapshots)
252+
cleanup_targets.append((snapshot.snapshot_id, dev_table_only))
253+
254+
snapshot_ids_to_cleanup = [snapshot_id for snapshot_id, _ in cleanup_targets]
255+
for snapshot_id_batch in create_batches(
256+
snapshot_ids_to_cleanup, batch_size=self.SNAPSHOT_BATCH_SIZE
257+
):
258+
snapshot_id_batch_set = set(snapshot_id_batch)
259+
full_snapshots = self._get_snapshots(snapshot_id_batch_set)
260+
cleanup_tasks = [
261+
SnapshotTableCleanupTask(
262+
snapshot=full_snapshots[snapshot_id].table_info,
263+
dev_table_only=dev_table_only,
252264
)
265+
for snapshot_id, dev_table_only in cleanup_targets
266+
if snapshot_id in full_snapshots
267+
]
268+
all_expired_snapshot_ids -= snapshot_id_batch_set
269+
yield snapshot_id_batch_set, cleanup_tasks
253270

254-
return expired_snapshot_ids, cleanup_targets
271+
if all_expired_snapshot_ids:
272+
# Remaining expired snapshots for which there are no tables
273+
# to clean up
274+
yield all_expired_snapshot_ids, []
255275

256276
def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
257277
"""Deletes snapshots.
@@ -593,14 +613,11 @@ def _get_snapshots_with_same_version(
593613
):
594614
query = (
595615
exp.select(
596-
"snapshot",
597616
"name",
598617
"identifier",
599618
"version",
600-
"updated_ts",
601-
"unpaused_ts",
602-
"unrestorable",
603-
"forward_only",
619+
"dev_version",
620+
"fingerprint",
604621
)
605622
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
606623
.where(where)
@@ -611,17 +628,14 @@ def _get_snapshots_with_same_version(
611628
snapshot_rows.extend(fetchall(self.engine_adapter, query))
612629

613630
return [
614-
SharedVersionSnapshot.from_snapshot_record(
631+
SharedVersionSnapshot(
615632
name=name,
616633
identifier=identifier,
617634
version=version,
618-
updated_ts=updated_ts,
619-
unpaused_ts=unpaused_ts,
620-
unrestorable=unrestorable,
621-
forward_only=forward_only,
622-
snapshot=snapshot,
635+
dev_version=dev_version,
636+
fingerprint=SnapshotFingerprint.parse_raw(fingerprint),
623637
)
624-
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
638+
for name, identifier, version, dev_version, fingerprint in snapshot_rows
625639
]
626640

627641

@@ -676,6 +690,8 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
676690
"ttl_ms": snapshot.ttl_ms,
677691
"unrestorable": snapshot.unrestorable,
678692
"forward_only": snapshot.forward_only,
693+
"dev_version": snapshot.dev_version,
694+
"fingerprint": snapshot.fingerprint.json(),
679695
}
680696
for snapshot in snapshots
681697
]
@@ -707,76 +723,11 @@ class SharedVersionSnapshot(PydanticModel):
707723
dev_version_: t.Optional[str] = Field(alias="dev_version")
708724
identifier: str
709725
fingerprint: SnapshotFingerprint
710-
interval_unit: IntervalUnit
711-
change_category: SnapshotChangeCategory
712-
updated_ts: int
713-
unpaused_ts: t.Optional[int]
714-
unrestorable: bool
715-
disable_restatement: bool
716-
effective_from: t.Optional[TimeLike]
717-
raw_snapshot: t.Dict[str, t.Any]
718-
forward_only: bool
719726

720727
@property
721728
def snapshot_id(self) -> SnapshotId:
722729
return SnapshotId(name=self.name, identifier=self.identifier)
723730

724-
@property
725-
def is_forward_only(self) -> bool:
726-
return self.forward_only or self.change_category == SnapshotChangeCategory.FORWARD_ONLY
727-
728-
@property
729-
def normalized_effective_from_ts(self) -> t.Optional[int]:
730-
return (
731-
to_timestamp(self.interval_unit.cron_floor(self.effective_from))
732-
if self.effective_from
733-
else None
734-
)
735-
736731
@property
737732
def dev_version(self) -> str:
738733
return self.dev_version_ or self.fingerprint.to_version()
739-
740-
@property
741-
def full_snapshot(self) -> Snapshot:
742-
return Snapshot(
743-
**{
744-
**self.raw_snapshot,
745-
"updated_ts": self.updated_ts,
746-
"unpaused_ts": self.unpaused_ts,
747-
"unrestorable": self.unrestorable,
748-
"forward_only": self.forward_only,
749-
}
750-
)
751-
752-
@classmethod
753-
def from_snapshot_record(
754-
cls,
755-
*,
756-
name: str,
757-
identifier: str,
758-
version: str,
759-
updated_ts: int,
760-
unpaused_ts: t.Optional[int],
761-
unrestorable: bool,
762-
forward_only: bool,
763-
snapshot: str,
764-
) -> SharedVersionSnapshot:
765-
raw_snapshot = json.loads(snapshot)
766-
raw_node = raw_snapshot["node"]
767-
return SharedVersionSnapshot(
768-
name=name,
769-
version=version,
770-
dev_version=raw_snapshot.get("dev_version"),
771-
identifier=identifier,
772-
fingerprint=raw_snapshot["fingerprint"],
773-
interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])),
774-
change_category=raw_snapshot["change_category"],
775-
updated_ts=updated_ts,
776-
unpaused_ts=unpaused_ts,
777-
unrestorable=unrestorable,
778-
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
779-
effective_from=raw_snapshot.get("effective_from"),
780-
raw_snapshot=raw_snapshot,
781-
forward_only=forward_only,
782-
)

0 commit comments

Comments
 (0)