Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre

@abc.abstractmethod
def get_expired_snapshots(
self, current_ts: int, ignore_ttl: bool = False
self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False
) -> t.List[SnapshotTableCleanupTask]:
"""Aggregates the id's of the expired snapshots and creates a list of table cleanup tasks.

Expand Down Expand Up @@ -341,7 +341,7 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
@abc.abstractmethod
def delete_expired_snapshots(
self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None
) -> t.List[SnapshotTableCleanupTask]:
) -> None:
"""Removes expired snapshots.

Expired snapshots are snapshots that have exceeded their time-to-live
Expand All @@ -350,9 +350,6 @@ def delete_expired_snapshots(
Args:
ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting
all snapshots that are not referenced in any environment

Returns:
The list of snapshot table cleanup tasks.
"""

@abc.abstractmethod
Expand Down
7 changes: 2 additions & 5 deletions sqlmesh/core/state_sync/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
SnapshotId,
SnapshotIdLike,
SnapshotInfoLike,
SnapshotTableCleanupTask,
)
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync
Expand Down Expand Up @@ -109,12 +108,10 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:

def delete_expired_snapshots(
self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None
) -> t.List[SnapshotTableCleanupTask]:
) -> None:
current_ts = current_ts or now_timestamp()
self.snapshot_cache.clear()
return self.state_sync.delete_expired_snapshots(
current_ts=current_ts, ignore_ttl=ignore_ttl
)
self.state_sync.delete_expired_snapshots(current_ts=current_ts, ignore_ttl=ignore_ttl)

def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None:
for snapshot_intervals in snapshots_intervals:
Expand Down
16 changes: 7 additions & 9 deletions sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,9 @@ def invalidate_environment(self, name: str, protect_prod: bool = True) -> None:
self.environment_state.invalidate_environment(name, protect_prod)

def get_expired_snapshots(
self, current_ts: int, ignore_ttl: bool = False
self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False
) -> t.List[SnapshotTableCleanupTask]:
current_ts = current_ts or now_timestamp()
return self.snapshot_state.get_expired_snapshots(
self.environment_state.get_environments(), current_ts=current_ts, ignore_ttl=ignore_ttl
)
Expand All @@ -274,16 +275,13 @@ def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary
@transactional()
def delete_expired_snapshots(
self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None
) -> t.List[SnapshotTableCleanupTask]:
) -> None:
current_ts = current_ts or now_timestamp()
expired_snapshot_ids, cleanup_targets = self.snapshot_state._get_expired_snapshots(
for expired_snapshot_ids, cleanup_targets in self.snapshot_state._get_expired_snapshots(
self.environment_state.get_environments(), ignore_ttl=ignore_ttl, current_ts=current_ts
)

self.snapshot_state.delete_snapshots(expired_snapshot_ids)
self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids)

return cleanup_targets
):
self.snapshot_state.delete_snapshots(expired_snapshot_ids)
self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids)

@transactional()
def delete_expired_environments(
Expand Down
139 changes: 45 additions & 94 deletions sqlmesh/core/state_sync/db/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
fetchall,
create_batches,
)
from sqlmesh.core.node import IntervalUnit
from sqlmesh.core.environment import Environment
from sqlmesh.core.model import SeedModel, ModelKindName
from sqlmesh.core.snapshot.cache import SnapshotCache
Expand All @@ -30,7 +29,6 @@
Snapshot,
SnapshotId,
SnapshotFingerprint,
SnapshotChangeCategory,
)
from sqlmesh.utils.migration import index_text_type, blob_text_type
from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp
Expand All @@ -46,6 +44,9 @@

class SnapshotState:
SNAPSHOT_BATCH_SIZE = 1000
# Use a smaller batch size for expired snapshots to account for fetching
# of all snapshots that share the same version.
EXPIRED_SNAPSHOT_BATCH_SIZE = 200

def __init__(
self,
Expand All @@ -63,13 +64,15 @@ def __init__(
"name": exp.DataType.build(index_type),
"identifier": exp.DataType.build(index_type),
"version": exp.DataType.build(index_type),
"dev_version": exp.DataType.build(index_type),
"snapshot": exp.DataType.build(blob_type),
"kind_name": exp.DataType.build("text"),
"updated_ts": exp.DataType.build("bigint"),
"unpaused_ts": exp.DataType.build("bigint"),
"ttl_ms": exp.DataType.build("bigint"),
"unrestorable": exp.DataType.build("boolean"),
"forward_only": exp.DataType.build("boolean"),
"fingerprint": exp.DataType.build(blob_type),
}

self._auto_restatement_columns_to_types = {
Expand Down Expand Up @@ -175,19 +178,21 @@ def get_expired_snapshots(
The set of expired snapshot ids.
The list of table cleanup tasks.
"""
_, cleanup_targets = self._get_expired_snapshots(
all_cleanup_targets = []
for _, cleanup_targets in self._get_expired_snapshots(
environments=environments,
current_ts=current_ts,
ignore_ttl=ignore_ttl,
)
return cleanup_targets
):
all_cleanup_targets.extend(cleanup_targets)
return all_cleanup_targets

def _get_expired_snapshots(
self,
environments: t.Iterable[Environment],
current_ts: int,
ignore_ttl: bool = False,
) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]:
) -> t.Iterator[t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]]:
expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table)

if not ignore_ttl:
Expand All @@ -202,7 +207,7 @@ def _get_expired_snapshots(
for name, identifier, version in fetchall(self.engine_adapter, expired_query)
}
if not expired_candidates:
return set(), []
return

promoted_snapshot_ids = {
snapshot.snapshot_id
Expand All @@ -218,10 +223,8 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:

unique_expired_versions = unique(expired_candidates.values())
version_batches = create_batches(
unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE
unique_expired_versions, batch_size=self.EXPIRED_SNAPSHOT_BATCH_SIZE
)
cleanup_targets = []
expired_snapshot_ids = set()
for versions_batch in version_batches:
snapshots = self._get_snapshots_with_same_version(versions_batch)

Expand All @@ -232,8 +235,9 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id)

expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)]
expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots])
all_expired_snapshot_ids = {s.snapshot_id for s in expired_snapshots}

cleanup_targets: t.List[t.Tuple[SnapshotId, bool]] = []
for snapshot in expired_snapshots:
shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)]
shared_version_snapshots.discard(snapshot.snapshot_id)
Expand All @@ -244,14 +248,30 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
shared_dev_version_snapshots.discard(snapshot.snapshot_id)

if not shared_dev_version_snapshots:
cleanup_targets.append(
SnapshotTableCleanupTask(
snapshot=snapshot.full_snapshot.table_info,
dev_table_only=bool(shared_version_snapshots),
)
dev_table_only = bool(shared_version_snapshots)
cleanup_targets.append((snapshot.snapshot_id, dev_table_only))

snapshot_ids_to_cleanup = [snapshot_id for snapshot_id, _ in cleanup_targets]
for snapshot_id_batch in create_batches(
snapshot_ids_to_cleanup, batch_size=self.SNAPSHOT_BATCH_SIZE
):
snapshot_id_batch_set = set(snapshot_id_batch)
full_snapshots = self._get_snapshots(snapshot_id_batch_set)
cleanup_tasks = [
SnapshotTableCleanupTask(
snapshot=full_snapshots[snapshot_id].table_info,
dev_table_only=dev_table_only,
)
for snapshot_id, dev_table_only in cleanup_targets
if snapshot_id in full_snapshots
]
all_expired_snapshot_ids -= snapshot_id_batch_set
yield snapshot_id_batch_set, cleanup_tasks

return expired_snapshot_ids, cleanup_targets
if all_expired_snapshot_ids:
# Remaining expired snapshots for which there are no tables
# to cleanup
yield all_expired_snapshot_ids, []

def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None:
"""Deletes snapshots.
Expand Down Expand Up @@ -593,14 +613,11 @@ def _get_snapshots_with_same_version(
):
query = (
exp.select(
"snapshot",
"name",
"identifier",
"version",
"updated_ts",
"unpaused_ts",
"unrestorable",
"forward_only",
"dev_version",
"fingerprint",
)
.from_(exp.to_table(self.snapshots_table).as_("snapshots"))
.where(where)
Expand All @@ -611,17 +628,14 @@ def _get_snapshots_with_same_version(
snapshot_rows.extend(fetchall(self.engine_adapter, query))

return [
SharedVersionSnapshot.from_snapshot_record(
SharedVersionSnapshot(
name=name,
identifier=identifier,
version=version,
updated_ts=updated_ts,
unpaused_ts=unpaused_ts,
unrestorable=unrestorable,
forward_only=forward_only,
snapshot=snapshot,
dev_version=dev_version,
fingerprint=SnapshotFingerprint.parse_raw(fingerprint),
)
for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows
for name, identifier, version, dev_version, fingerprint in snapshot_rows
]


Expand Down Expand Up @@ -676,6 +690,8 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
"ttl_ms": snapshot.ttl_ms,
"unrestorable": snapshot.unrestorable,
"forward_only": snapshot.forward_only,
"dev_version": snapshot.dev_version,
"fingerprint": snapshot.fingerprint.json(),
}
for snapshot in snapshots
]
Expand Down Expand Up @@ -707,76 +723,11 @@ class SharedVersionSnapshot(PydanticModel):
dev_version_: t.Optional[str] = Field(alias="dev_version")
identifier: str
fingerprint: SnapshotFingerprint
interval_unit: IntervalUnit
change_category: SnapshotChangeCategory
updated_ts: int
unpaused_ts: t.Optional[int]
unrestorable: bool
disable_restatement: bool
effective_from: t.Optional[TimeLike]
raw_snapshot: t.Dict[str, t.Any]
forward_only: bool

@property
def snapshot_id(self) -> SnapshotId:
return SnapshotId(name=self.name, identifier=self.identifier)

@property
def is_forward_only(self) -> bool:
return self.forward_only or self.change_category == SnapshotChangeCategory.FORWARD_ONLY

@property
def normalized_effective_from_ts(self) -> t.Optional[int]:
return (
to_timestamp(self.interval_unit.cron_floor(self.effective_from))
if self.effective_from
else None
)

@property
def dev_version(self) -> str:
return self.dev_version_ or self.fingerprint.to_version()

@property
def full_snapshot(self) -> Snapshot:
return Snapshot(
**{
**self.raw_snapshot,
"updated_ts": self.updated_ts,
"unpaused_ts": self.unpaused_ts,
"unrestorable": self.unrestorable,
"forward_only": self.forward_only,
}
)

@classmethod
def from_snapshot_record(
cls,
*,
name: str,
identifier: str,
version: str,
updated_ts: int,
unpaused_ts: t.Optional[int],
unrestorable: bool,
forward_only: bool,
snapshot: str,
) -> SharedVersionSnapshot:
raw_snapshot = json.loads(snapshot)
raw_node = raw_snapshot["node"]
return SharedVersionSnapshot(
name=name,
version=version,
dev_version=raw_snapshot.get("dev_version"),
identifier=identifier,
fingerprint=raw_snapshot["fingerprint"],
interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])),
change_category=raw_snapshot["change_category"],
updated_ts=updated_ts,
unpaused_ts=unpaused_ts,
unrestorable=unrestorable,
disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False),
effective_from=raw_snapshot.get("effective_from"),
raw_snapshot=raw_snapshot,
forward_only=forward_only,
)
Loading