1717 fetchall ,
1818 create_batches ,
1919)
20- from sqlmesh .core .node import IntervalUnit
2120from sqlmesh .core .environment import Environment
2221from sqlmesh .core .model import SeedModel , ModelKindName
2322from sqlmesh .core .snapshot .cache import SnapshotCache
3029 Snapshot ,
3130 SnapshotId ,
3231 SnapshotFingerprint ,
33- SnapshotChangeCategory ,
3432)
3533from sqlmesh .utils .migration import index_text_type , blob_text_type
3634from sqlmesh .utils .date import now_timestamp , TimeLike , to_timestamp
4644
4745class SnapshotState :
4846 SNAPSHOT_BATCH_SIZE = 1000
47+ # Use a smaller batch size for expired snapshots to account for fetching
48+ # of all snapshots that share the same version.
49+ EXPIRED_SNAPSHOT_BATCH_SIZE = 200
4950
5051 def __init__ (
5152 self ,
@@ -63,13 +64,15 @@ def __init__(
6364 "name" : exp .DataType .build (index_type ),
6465 "identifier" : exp .DataType .build (index_type ),
6566 "version" : exp .DataType .build (index_type ),
67+ "dev_version" : exp .DataType .build (index_type ),
6668 "snapshot" : exp .DataType .build (blob_type ),
6769 "kind_name" : exp .DataType .build ("text" ),
6870 "updated_ts" : exp .DataType .build ("bigint" ),
6971 "unpaused_ts" : exp .DataType .build ("bigint" ),
7072 "ttl_ms" : exp .DataType .build ("bigint" ),
7173 "unrestorable" : exp .DataType .build ("boolean" ),
7274 "forward_only" : exp .DataType .build ("boolean" ),
75+ "fingerprint" : exp .DataType .build (blob_type ),
7376 }
7477
7578 self ._auto_restatement_columns_to_types = {
@@ -175,19 +178,21 @@ def get_expired_snapshots(
175178 The set of expired snapshot ids.
176179 The list of table cleanup tasks.
177180 """
178- _ , cleanup_targets = self ._get_expired_snapshots (
181+ all_cleanup_targets = []
182+ for _ , cleanup_targets in self ._get_expired_snapshots (
179183 environments = environments ,
180184 current_ts = current_ts ,
181185 ignore_ttl = ignore_ttl ,
182- )
183- return cleanup_targets
186+ ):
187+ all_cleanup_targets .extend (cleanup_targets )
188+ return all_cleanup_targets
184189
185190 def _get_expired_snapshots (
186191 self ,
187192 environments : t .Iterable [Environment ],
188193 current_ts : int ,
189194 ignore_ttl : bool = False ,
190- ) -> t .Tuple [t .Set [SnapshotId ], t .List [SnapshotTableCleanupTask ]]:
195+ ) -> t .Iterator [ t . Tuple [t .Set [SnapshotId ], t .List [SnapshotTableCleanupTask ] ]]:
191196 expired_query = exp .select ("name" , "identifier" , "version" ).from_ (self .snapshots_table )
192197
193198 if not ignore_ttl :
@@ -202,7 +207,7 @@ def _get_expired_snapshots(
202207 for name , identifier , version in fetchall (self .engine_adapter , expired_query )
203208 }
204209 if not expired_candidates :
205- return set (), []
210+ return
206211
207212 promoted_snapshot_ids = {
208213 snapshot .snapshot_id
@@ -218,10 +223,8 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
218223
219224 unique_expired_versions = unique (expired_candidates .values ())
220225 version_batches = create_batches (
221- unique_expired_versions , batch_size = self .SNAPSHOT_BATCH_SIZE
226+ unique_expired_versions , batch_size = self .EXPIRED_SNAPSHOT_BATCH_SIZE
222227 )
223- cleanup_targets = []
224- expired_snapshot_ids = set ()
225228 for versions_batch in version_batches :
226229 snapshots = self ._get_snapshots_with_same_version (versions_batch )
227230
@@ -232,8 +235,9 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
232235 snapshots_by_dev_version [(s .name , s .dev_version )].add (s .snapshot_id )
233236
234237 expired_snapshots = [s for s in snapshots if not _is_snapshot_used (s )]
235- expired_snapshot_ids . update ([ s .snapshot_id for s in expired_snapshots ])
238+ all_expired_snapshot_ids = { s .snapshot_id for s in expired_snapshots }
236239
240+ cleanup_targets : t .List [t .Tuple [SnapshotId , bool ]] = []
237241 for snapshot in expired_snapshots :
238242 shared_version_snapshots = snapshots_by_version [(snapshot .name , snapshot .version )]
239243 shared_version_snapshots .discard (snapshot .snapshot_id )
@@ -244,14 +248,30 @@ def _is_snapshot_used(snapshot: SharedVersionSnapshot) -> bool:
244248 shared_dev_version_snapshots .discard (snapshot .snapshot_id )
245249
246250 if not shared_dev_version_snapshots :
247- cleanup_targets .append (
248- SnapshotTableCleanupTask (
249- snapshot = snapshot .full_snapshot .table_info ,
250- dev_table_only = bool (shared_version_snapshots ),
251- )
251+ dev_table_only = bool (shared_version_snapshots )
252+ cleanup_targets .append ((snapshot .snapshot_id , dev_table_only ))
253+
254+ snapshot_ids_to_cleanup = [snapshot_id for snapshot_id , _ in cleanup_targets ]
255+ for snapshot_id_batch in create_batches (
256+ snapshot_ids_to_cleanup , batch_size = self .SNAPSHOT_BATCH_SIZE
257+ ):
258+ snapshot_id_batch_set = set (snapshot_id_batch )
259+ full_snapshots = self ._get_snapshots (snapshot_id_batch_set )
260+ cleanup_tasks = [
261+ SnapshotTableCleanupTask (
262+ snapshot = full_snapshots [snapshot_id ].table_info ,
263+ dev_table_only = dev_table_only ,
252264 )
265+ for snapshot_id , dev_table_only in cleanup_targets
266+ if snapshot_id in full_snapshots
267+ ]
268+ all_expired_snapshot_ids -= snapshot_id_batch_set
269+ yield snapshot_id_batch_set , cleanup_tasks
253270
254- return expired_snapshot_ids , cleanup_targets
271+ if all_expired_snapshot_ids :
272+ # Remaining expired snapshots for which there are no tables
273+ # to cleanup
274+ yield all_expired_snapshot_ids , []
255275
256276 def delete_snapshots (self , snapshot_ids : t .Iterable [SnapshotIdLike ]) -> None :
257277 """Deletes snapshots.
@@ -593,14 +613,11 @@ def _get_snapshots_with_same_version(
593613 ):
594614 query = (
595615 exp .select (
596- "snapshot" ,
597616 "name" ,
598617 "identifier" ,
599618 "version" ,
600- "updated_ts" ,
601- "unpaused_ts" ,
602- "unrestorable" ,
603- "forward_only" ,
619+ "dev_version" ,
620+ "fingerprint" ,
604621 )
605622 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
606623 .where (where )
@@ -611,17 +628,14 @@ def _get_snapshots_with_same_version(
611628 snapshot_rows .extend (fetchall (self .engine_adapter , query ))
612629
613630 return [
614- SharedVersionSnapshot . from_snapshot_record (
631+ SharedVersionSnapshot (
615632 name = name ,
616633 identifier = identifier ,
617634 version = version ,
618- updated_ts = updated_ts ,
619- unpaused_ts = unpaused_ts ,
620- unrestorable = unrestorable ,
621- forward_only = forward_only ,
622- snapshot = snapshot ,
635+ dev_version = dev_version ,
636+ fingerprint = SnapshotFingerprint .parse_raw (fingerprint ),
623637 )
624- for snapshot , name , identifier , version , updated_ts , unpaused_ts , unrestorable , forward_only in snapshot_rows
638+ for name , identifier , version , dev_version , fingerprint in snapshot_rows
625639 ]
626640
627641
@@ -676,6 +690,8 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
676690 "ttl_ms" : snapshot .ttl_ms ,
677691 "unrestorable" : snapshot .unrestorable ,
678692 "forward_only" : snapshot .forward_only ,
693+ "dev_version" : snapshot .dev_version ,
694+ "fingerprint" : snapshot .fingerprint .json (),
679695 }
680696 for snapshot in snapshots
681697 ]
@@ -707,76 +723,11 @@ class SharedVersionSnapshot(PydanticModel):
707723 dev_version_ : t .Optional [str ] = Field (alias = "dev_version" )
708724 identifier : str
709725 fingerprint : SnapshotFingerprint
710- interval_unit : IntervalUnit
711- change_category : SnapshotChangeCategory
712- updated_ts : int
713- unpaused_ts : t .Optional [int ]
714- unrestorable : bool
715- disable_restatement : bool
716- effective_from : t .Optional [TimeLike ]
717- raw_snapshot : t .Dict [str , t .Any ]
718- forward_only : bool
719726
720727 @property
721728 def snapshot_id (self ) -> SnapshotId :
722729 return SnapshotId (name = self .name , identifier = self .identifier )
723730
724- @property
725- def is_forward_only (self ) -> bool :
726- return self .forward_only or self .change_category == SnapshotChangeCategory .FORWARD_ONLY
727-
728- @property
729- def normalized_effective_from_ts (self ) -> t .Optional [int ]:
730- return (
731- to_timestamp (self .interval_unit .cron_floor (self .effective_from ))
732- if self .effective_from
733- else None
734- )
735-
736731 @property
737732 def dev_version (self ) -> str :
738733 return self .dev_version_ or self .fingerprint .to_version ()
739-
740- @property
741- def full_snapshot (self ) -> Snapshot :
742- return Snapshot (
743- ** {
744- ** self .raw_snapshot ,
745- "updated_ts" : self .updated_ts ,
746- "unpaused_ts" : self .unpaused_ts ,
747- "unrestorable" : self .unrestorable ,
748- "forward_only" : self .forward_only ,
749- }
750- )
751-
752- @classmethod
753- def from_snapshot_record (
754- cls ,
755- * ,
756- name : str ,
757- identifier : str ,
758- version : str ,
759- updated_ts : int ,
760- unpaused_ts : t .Optional [int ],
761- unrestorable : bool ,
762- forward_only : bool ,
763- snapshot : str ,
764- ) -> SharedVersionSnapshot :
765- raw_snapshot = json .loads (snapshot )
766- raw_node = raw_snapshot ["node" ]
767- return SharedVersionSnapshot (
768- name = name ,
769- version = version ,
770- dev_version = raw_snapshot .get ("dev_version" ),
771- identifier = identifier ,
772- fingerprint = raw_snapshot ["fingerprint" ],
773- interval_unit = raw_node .get ("interval_unit" , IntervalUnit .from_cron (raw_node ["cron" ])),
774- change_category = raw_snapshot ["change_category" ],
775- updated_ts = updated_ts ,
776- unpaused_ts = unpaused_ts ,
777- unrestorable = unrestorable ,
778- disable_restatement = raw_node .get ("kind" , {}).get ("disable_restatement" , False ),
779- effective_from = raw_snapshot .get ("effective_from" ),
780- raw_snapshot = raw_snapshot ,
781- forward_only = forward_only ,
782- )
0 commit comments