diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 6c2097d760..4d3d51a469 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -292,7 +292,7 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre @abc.abstractmethod def get_expired_snapshots( - self, current_ts: int, ignore_ttl: bool = False + self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False ) -> t.List[SnapshotTableCleanupTask]: """Aggregates the id's of the expired snapshots and creates a list of table cleanup tasks. @@ -341,7 +341,7 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: @abc.abstractmethod def delete_expired_snapshots( self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None - ) -> t.List[SnapshotTableCleanupTask]: + ) -> None: """Removes expired snapshots. Expired snapshots are snapshots that have exceeded their time-to-live @@ -350,9 +350,6 @@ def delete_expired_snapshots( Args: ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment - - Returns: - The list of snapshot table cleanup tasks. """ @abc.abstractmethod diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index cc6a0fcb86..8aa5054e13 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -8,7 +8,6 @@ SnapshotId, SnapshotIdLike, SnapshotInfoLike, - SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync @@ -109,12 +108,10 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: def delete_expired_snapshots( self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None - ) -> t.List[SnapshotTableCleanupTask]: + ) -> None: current_ts = current_ts or now_timestamp() self.snapshot_cache.clear() - return self.state_sync.delete_expired_snapshots( - current_ts=current_ts, ignore_ttl=ignore_ttl - ) + self.state_sync.delete_expired_snapshots(current_ts=current_ts, ignore_ttl=ignore_ttl) def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: for snapshot_intervals in snapshots_intervals: diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 858b1aa072..898ba75651 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -262,8 +262,9 @@ def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: self.environment_state.invalidate_environment(name, protect_prod) def get_expired_snapshots( - self, current_ts: int, ignore_ttl: bool = False + self, current_ts: t.Optional[int] = None, ignore_ttl: bool = False ) -> t.List[SnapshotTableCleanupTask]: + current_ts = current_ts or now_timestamp() return self.snapshot_state.get_expired_snapshots( self.environment_state.get_environments(), current_ts=current_ts, ignore_ttl=ignore_ttl ) @@ -274,16 +275,13 @@ def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary @transactional() def delete_expired_snapshots( self, ignore_ttl: bool = False, current_ts: t.Optional[int] = None - ) -> t.List[SnapshotTableCleanupTask]: + ) -> None: current_ts = current_ts or now_timestamp() - expired_snapshot_ids, cleanup_targets = self.snapshot_state._get_expired_snapshots( + for expired_snapshot_ids, cleanup_targets in self.snapshot_state._get_expired_snapshots( self.environment_state.get_environments(), ignore_ttl=ignore_ttl, current_ts=current_ts - ) - - self.snapshot_state.delete_snapshots(expired_snapshot_ids) - self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids) - - return cleanup_targets + ): + self.snapshot_state.delete_snapshots(expired_snapshot_ids) + self.interval_state.cleanup_intervals(cleanup_targets, expired_snapshot_ids) @transactional() def delete_expired_environments( diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py index 8d504993fc..af10f0192e 100644 --- a/sqlmesh/core/state_sync/db/snapshot.py +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -17,7 +17,6 @@ fetchall, create_batches, ) -from sqlmesh.core.node import IntervalUnit from sqlmesh.core.environment import Environment from sqlmesh.core.model import SeedModel, ModelKindName from sqlmesh.core.snapshot.cache import SnapshotCache @@ -30,7 +29,6 @@ Snapshot, SnapshotId, SnapshotFingerprint, - SnapshotChangeCategory, ) from sqlmesh.utils.migration import index_text_type, blob_text_type from sqlmesh.utils.date import now_timestamp, TimeLike, to_timestamp @@ -46,6 +44,9 @@ class SnapshotState: SNAPSHOT_BATCH_SIZE = 1000 + # Use a smaller batch size for expired snapshots to account for fetching + # of all snapshots that share the same version. + EXPIRED_SNAPSHOT_BATCH_SIZE = 200 def __init__( self, @@ -63,6 +64,7 @@ def __init__( "name": exp.DataType.build(index_type), "identifier": exp.DataType.build(index_type), "version": exp.DataType.build(index_type), + "dev_version": exp.DataType.build(index_type), "snapshot": exp.DataType.build(blob_type), "kind_name": exp.DataType.build("text"), "updated_ts": exp.DataType.build("bigint"), @@ -70,6 +72,7 @@ def __init__( "ttl_ms": exp.DataType.build("bigint"), "unrestorable": exp.DataType.build("boolean"), "forward_only": exp.DataType.build("boolean"), + "fingerprint": exp.DataType.build(blob_type), } self._auto_restatement_columns_to_types = { @@ -175,19 +178,21 @@ def get_expired_snapshots( The set of expired snapshot ids. The list of table cleanup tasks. """ - _, cleanup_targets = self._get_expired_snapshots( + all_cleanup_targets = [] + for _, cleanup_targets in self._get_expired_snapshots( environments=environments, current_ts=current_ts, ignore_ttl=ignore_ttl, - ) - return cleanup_targets + ): + all_cleanup_targets.extend(cleanup_targets) + return all_cleanup_targets def _get_expired_snapshots( self, environments: t.Iterable[Environment], current_ts: int, ignore_ttl: bool = False, - ) -> t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]: + ) -> t.Iterator[t.Tuple[t.Set[SnapshotId], t.List[SnapshotTableCleanupTask]]]: expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) if not ignore_ttl: @@ -202,7 +207,7 @@ def _get_expired_snapshots( for name, identifier, version in fetchall(self.engine_adapter, expired_query) } if not expired_candidates: - return set(), [] + return promoted_snapshot_ids = { snapshot.snapshot_id @@ -218,10 +223,8 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: unique_expired_versions = unique(expired_candidates.values()) version_batches = create_batches( - unique_expired_versions, batch_size=self.SNAPSHOT_BATCH_SIZE + unique_expired_versions, batch_size=self.EXPIRED_SNAPSHOT_BATCH_SIZE ) - cleanup_targets = [] - expired_snapshot_ids = set() for versions_batch in version_batches: snapshots = self._get_snapshots_with_same_version(versions_batch) @@ -232,8 +235,9 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] - expired_snapshot_ids.update([s.snapshot_id for s in expired_snapshots]) + all_expired_snapshot_ids = {s.snapshot_id for s in expired_snapshots} + cleanup_targets: t.List[t.Tuple[SnapshotId, bool]] = [] for snapshot in expired_snapshots: shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] shared_version_snapshots.discard(snapshot.snapshot_id) @@ -244,14 +248,30 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool: shared_dev_version_snapshots.discard(snapshot.snapshot_id) if not shared_dev_version_snapshots: - cleanup_targets.append( - SnapshotTableCleanupTask( - snapshot=snapshot.full_snapshot.table_info, - dev_table_only=bool(shared_version_snapshots), - ) + dev_table_only = bool(shared_version_snapshots) + cleanup_targets.append((snapshot.snapshot_id, dev_table_only)) + + snapshot_ids_to_cleanup = [snapshot_id for snapshot_id, _ in cleanup_targets] + for snapshot_id_batch in create_batches( + snapshot_ids_to_cleanup, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + snapshot_id_batch_set = set(snapshot_id_batch) + full_snapshots = self._get_snapshots(snapshot_id_batch_set) + cleanup_tasks = [ + SnapshotTableCleanupTask( + snapshot=full_snapshots[snapshot_id].table_info, + dev_table_only=dev_table_only, ) + for snapshot_id, dev_table_only in cleanup_targets + if snapshot_id in full_snapshots + ] + all_expired_snapshot_ids -= snapshot_id_batch_set + yield snapshot_id_batch_set, cleanup_tasks - return expired_snapshot_ids, cleanup_targets + if all_expired_snapshot_ids: + # Remaining expired snapshots for which there are no tables + # to cleanup + yield all_expired_snapshot_ids, [] def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: """Deletes snapshots. @@ -593,14 +613,11 @@ def _get_snapshots_with_same_version( ): query = ( exp.select( - "snapshot", "name", "identifier", "version", - "updated_ts", - "unpaused_ts", - "unrestorable", - "forward_only", + "dev_version", + "fingerprint", ) .from_(exp.to_table(self.snapshots_table).as_("snapshots")) .where(where) @@ -611,17 +628,14 @@ def _get_snapshots_with_same_version( snapshot_rows.extend(fetchall(self.engine_adapter, query)) return [ - SharedVersionSnapshot.from_snapshot_record( + SharedVersionSnapshot( name=name, identifier=identifier, version=version, - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - forward_only=forward_only, - snapshot=snapshot, + dev_version=dev_version, + fingerprint=SnapshotFingerprint.parse_raw(fingerprint), ) - for snapshot, name, identifier, version, updated_ts, unpaused_ts, unrestorable, forward_only in snapshot_rows + for name, identifier, version, dev_version, fingerprint in snapshot_rows ] @@ -676,6 +690,8 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: "ttl_ms": snapshot.ttl_ms, "unrestorable": snapshot.unrestorable, "forward_only": snapshot.forward_only, + "dev_version": snapshot.dev_version, + "fingerprint": snapshot.fingerprint.json(), } for snapshot in snapshots ] @@ -707,76 +723,11 @@ class SharedVersionSnapshot(PydanticModel): dev_version_: t.Optional[str] = Field(alias="dev_version") identifier: str fingerprint: SnapshotFingerprint - interval_unit: IntervalUnit - change_category: SnapshotChangeCategory - updated_ts: int - unpaused_ts: t.Optional[int] - unrestorable: bool - disable_restatement: bool - effective_from: t.Optional[TimeLike] - raw_snapshot: t.Dict[str, t.Any] - forward_only: bool @property def snapshot_id(self) -> SnapshotId: return SnapshotId(name=self.name, identifier=self.identifier) - @property - def is_forward_only(self) -> bool: - return self.forward_only or self.change_category == SnapshotChangeCategory.FORWARD_ONLY - - @property - def normalized_effective_from_ts(self) -> t.Optional[int]: - return ( - to_timestamp(self.interval_unit.cron_floor(self.effective_from)) - if self.effective_from - else None - ) - @property def dev_version(self) -> str: return self.dev_version_ or self.fingerprint.to_version() - - @property - def full_snapshot(self) -> Snapshot: - return Snapshot( - **{ - **self.raw_snapshot, - "updated_ts": self.updated_ts, - "unpaused_ts": self.unpaused_ts, - "unrestorable": self.unrestorable, - "forward_only": self.forward_only, - } - ) - - @classmethod - def from_snapshot_record( - cls, - *, - name: str, - identifier: str, - version: str, - updated_ts: int, - unpaused_ts: t.Optional[int], - unrestorable: bool, - forward_only: bool, - snapshot: str, - ) -> SharedVersionSnapshot: - raw_snapshot = json.loads(snapshot) - raw_node = raw_snapshot["node"] - return SharedVersionSnapshot( - name=name, - version=version, - dev_version=raw_snapshot.get("dev_version"), - identifier=identifier, - fingerprint=raw_snapshot["fingerprint"], - interval_unit=raw_node.get("interval_unit", IntervalUnit.from_cron(raw_node["cron"])), - change_category=raw_snapshot["change_category"], - updated_ts=updated_ts, - unpaused_ts=unpaused_ts, - unrestorable=unrestorable, - disable_restatement=raw_node.get("kind", {}).get("disable_restatement", False), - effective_from=raw_snapshot.get("effective_from"), - raw_snapshot=raw_snapshot, - forward_only=forward_only, - ) diff --git a/sqlmesh/migrations/v0094_add_dev_version_and_fingerprint_columns.py b/sqlmesh/migrations/v0094_add_dev_version_and_fingerprint_columns.py new file mode 100644 index 0000000000..0163b36ab4 --- /dev/null +++ b/sqlmesh/migrations/v0094_add_dev_version_and_fingerprint_columns.py @@ -0,0 +1,116 @@ +"""Add dev_version and fingerprint columns to the snapshots table.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate(state_sync, **kwargs): # type: ignore + import pandas as pd + + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + add_dev_version_exp = exp.Alter( + this=exp.to_table(snapshots_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("dev_version"), + kind=exp.DataType.build(index_type), + ) + ], + ) + engine_adapter.execute(add_dev_version_exp) + + add_fingerprint_exp = exp.Alter( + this=exp.to_table(snapshots_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("fingerprint"), + kind=exp.DataType.build(blob_type), + ) + ], + ) + engine_adapter.execute(add_fingerprint_exp) + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + forward_only, + _, + _, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + "forward_only", + "dev_version", + "fingerprint", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": snapshot, + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + "forward_only": forward_only, + "dev_version": parsed_snapshot.get("dev_version"), + "fingerprint": json.dumps(parsed_snapshot.get("fingerprint")), + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), + "dev_version": exp.DataType.build(index_type), + "fingerprint": exp.DataType.build(blob_type), + }, + ) diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index e7046be13d..be8e4ad3e0 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -1156,10 +1156,11 @@ def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snaps new_snapshot.snapshot_id, } - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), ] + state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots(all_snapshots) @@ -1186,9 +1187,10 @@ def test_delete_expired_snapshots_seed( state_sync.push_snapshots(all_snapshots) assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False), ] + state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots(all_snapshots) @@ -1226,10 +1228,11 @@ def test_delete_expired_snapshots_batching( snapshot_b.snapshot_id, } - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False), SnapshotTableCleanupTask(snapshot=snapshot_b.table_info, dev_table_only=False), ] + state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots(all_snapshots) @@ -1262,7 +1265,8 @@ def test_delete_expired_snapshots_promoted( state_sync.promote(env) all_snapshots = [snapshot] - assert not state_sync.delete_expired_snapshots() + assert not state_sync.get_expired_snapshots() + state_sync.delete_expired_snapshots() assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} env.snapshots_ = [] @@ -1271,9 +1275,10 @@ def test_delete_expired_snapshots_promoted( now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.db.facade.now_timestamp") now_timestamp_mock.return_value = now_timestamp() + 11000 - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False) ] + state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots(all_snapshots) @@ -1310,9 +1315,10 @@ def test_delete_expired_snapshots_dev_table_cleanup_only( new_snapshot.snapshot_id, } - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True) ] + state_sync.delete_expired_snapshots() assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} @@ -1351,7 +1357,8 @@ def test_delete_expired_snapshots_shared_dev_table( new_snapshot.snapshot_id, } - assert not state_sync.delete_expired_snapshots() # No dev table cleanup + assert not state_sync.get_expired_snapshots() # No dev table cleanup + state_sync.delete_expired_snapshots() assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} @@ -1396,13 +1403,17 @@ def test_delete_expired_snapshots_ignore_ttl( state_sync.promote(env) # default TTL = 1 week, nothing to clean up yet if we take TTL into account - assert not state_sync.delete_expired_snapshots() + assert not state_sync.get_expired_snapshots() + state_sync.delete_expired_snapshots() + assert state_sync.snapshots_exist([snapshot_c.snapshot_id]) == {snapshot_c.snapshot_id} # If we ignore TTL, only snapshot_c should get cleaned up because snapshot_a and snapshot_b are part of an environment assert snapshot_a.table_info != snapshot_b.table_info != snapshot_c.table_info - assert state_sync.delete_expired_snapshots(ignore_ttl=True) == [ + assert state_sync.get_expired_snapshots(ignore_ttl=True) == [ SnapshotTableCleanupTask(snapshot=snapshot_c.table_info, dev_table_only=False) ] + state_sync.delete_expired_snapshots(ignore_ttl=True) + assert not state_sync.snapshots_exist([snapshot_c.snapshot_id]) def test_delete_expired_snapshots_cleanup_intervals( @@ -1465,10 +1476,11 @@ def test_delete_expired_snapshots_cleanup_intervals( ] assert not stored_new_snapshot.dev_intervals - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), ] + state_sync.delete_expired_snapshots() assert not get_snapshot_intervals(snapshot) @@ -1552,9 +1564,10 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_version( ) # Delete the expired snapshot - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), ] + state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots([snapshot]) # Check new snapshot's intervals @@ -1671,7 +1684,8 @@ def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( ) # Delete the expired snapshot - assert state_sync.delete_expired_snapshots() == [] + assert state_sync.get_expired_snapshots() == [] + state_sync.delete_expired_snapshots() assert not state_sync.get_snapshots([snapshot]) # Check new snapshot's intervals @@ -1764,9 +1778,10 @@ def test_compact_intervals_after_cleanup( state_sync.add_interval(snapshot_c, "2023-01-07", "2023-01-09", is_dev=True) # Only the dev table of the original snapshot should be deleted - assert state_sync.delete_expired_snapshots() == [ + assert state_sync.get_expired_snapshots() == [ SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=True), ] + state_sync.delete_expired_snapshots() assert state_sync.engine_adapter.fetchone("SELECT COUNT(*) FROM sqlmesh._intervals")[0] == 5 # type: ignore