2727from datetime import datetime
2828
2929import pandas as pd
30+ from pydantic import Field
3031from sqlglot import __version__ as SQLGLOT_VERSION
3132from sqlglot import exp
3233from sqlglot .helper import seq_get
3738from sqlmesh .core .engine_adapter import EngineAdapter
3839from sqlmesh .core .environment import Environment
3940from sqlmesh .core .model import ModelKindName , SeedModel
41+ from sqlmesh .core .node import IntervalUnit
4042from sqlmesh .core .snapshot import (
4143 Intervals ,
4244 Node ,
4345 Snapshot ,
46+ SnapshotChangeCategory ,
4447 SnapshotFingerprint ,
4548 SnapshotId ,
4649 SnapshotIdLike ,
7376from sqlmesh .utils .date import TimeLike , now , now_timestamp , time_like_to_str , to_timestamp
7477from sqlmesh .utils .errors import ConflictingPlanError , SQLMeshError
7578from sqlmesh .utils .migration import blob_text_type , index_text_type
79+ from sqlmesh .utils .pydantic import PydanticModel
7680
7781logger = logging .getLogger (__name__ )
7882
@@ -434,16 +438,20 @@ def unpause_snapshots(
434438 current_ts = now ()
435439
436440 target_snapshot_ids = {s .snapshot_id for s in snapshots }
437- snapshots = self ._get_snapshots_with_same_version (snapshots , lock_for_update = True )
441+ same_version_snapshots = self ._get_snapshots_with_same_version (
442+ snapshots , lock_for_update = True
443+ )
438444 target_snapshots_by_version = {
439- (s .name , s .version ): s for s in snapshots if s .snapshot_id in target_snapshot_ids
445+ (s .name , s .version ): s
446+ for s in same_version_snapshots
447+ if s .snapshot_id in target_snapshot_ids
440448 }
441449
442450 unpaused_snapshots : t .Dict [int , t .List [SnapshotId ]] = defaultdict (list )
443451 paused_snapshots : t .List [SnapshotId ] = []
444452 unrestorable_snapshots : t .List [SnapshotId ] = []
445453
446- for snapshot in snapshots :
454+ for snapshot in same_version_snapshots :
447455 is_target_snapshot = snapshot .snapshot_id in target_snapshot_ids
448456 if is_target_snapshot and not snapshot .unpaused_ts :
449457 logger .info ("Unpausing snapshot %s" , snapshot .snapshot_id )
@@ -464,8 +472,14 @@ def unpause_snapshots(
464472 snapshot .snapshot_id ,
465473 target_snapshot .snapshot_id ,
466474 )
475+ full_snapshot = snapshot .full_snapshot
467476 self .remove_intervals (
468- [(snapshot , snapshot .get_removal_interval (effective_from_ts , current_ts ))]
477+ [
478+ (
479+ full_snapshot ,
480+ full_snapshot .get_removal_interval (effective_from_ts , current_ts ),
481+ )
482+ ]
469483 )
470484
471485 if snapshot .unpaused_ts :
@@ -555,7 +569,7 @@ def delete_expired_snapshots(
555569 for snapshot in environment .snapshots
556570 }
557571
558- def _is_snapshot_used (snapshot : Snapshot ) -> bool :
572+ def _is_snapshot_used (snapshot : SharedVersionSnapshot ) -> bool :
559573 return (
560574 snapshot .snapshot_id in promoted_snapshot_ids
561575 or snapshot .snapshot_id not in expired_candidates
@@ -571,28 +585,26 @@ def _is_snapshot_used(snapshot: Snapshot) -> bool:
571585 snapshots_by_temp_version = defaultdict (set )
572586 for s in snapshots :
573587 snapshots_by_version [(s .name , s .version )].add (s .snapshot_id )
574- snapshots_by_temp_version [(s .name , s .temp_version_get_or_generate ())].add (
575- s .snapshot_id
576- )
588+ snapshots_by_temp_version [(s .name , s .temp_version )].add (s .snapshot_id )
577589
578590 expired_snapshots = [s for s in snapshots if not _is_snapshot_used (s )]
579591
580592 if expired_snapshots :
581- self .delete_snapshots (expired_snapshots )
593+ self .delete_snapshots ([ s . snapshot_id for s in expired_snapshots ] )
582594
583595 for snapshot in expired_snapshots :
584596 shared_version_snapshots = snapshots_by_version [(snapshot .name , snapshot .version )]
585597 shared_version_snapshots .discard (snapshot .snapshot_id )
586598
587599 shared_temp_version_snapshots = snapshots_by_temp_version [
588- (snapshot .name , snapshot .temp_version_get_or_generate () )
600+ (snapshot .name , snapshot .temp_version )
589601 ]
590602 shared_temp_version_snapshots .discard (snapshot .snapshot_id )
591603
592604 if not shared_temp_version_snapshots :
593605 cleanup_targets .append (
594606 SnapshotTableCleanupTask (
595- snapshot = snapshot .table_info ,
607+ snapshot = snapshot .full_snapshot . table_info ,
596608 dev_table_only = bool (shared_version_snapshots ),
597609 )
598610 )
@@ -903,7 +915,7 @@ def _get_snapshots_with_same_version(
903915 self ,
904916 snapshots : t .Collection [SnapshotNameVersionLike ],
905917 lock_for_update : bool = False ,
906- ) -> t .List [Snapshot ]:
918+ ) -> t .List [SharedVersionSnapshot ]:
907919 """Fetches all snapshots that share the same version as the snapshots.
908920
909921 The output includes the snapshots with the specified identifiers.
@@ -922,7 +934,15 @@ def _get_snapshots_with_same_version(
922934
923935 for where in self ._snapshot_name_version_filter (snapshots ):
924936 query = (
925- exp .select ("snapshot" , "updated_ts" , "unpaused_ts" , "unrestorable" )
937+ exp .select (
938+ "snapshot" ,
939+ "name" ,
940+ "identifier" ,
941+ "version" ,
942+ "updated_ts" ,
943+ "unpaused_ts" ,
944+ "unrestorable" ,
945+ )
926946 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
927947 .where (where )
928948 )
@@ -932,15 +952,16 @@ def _get_snapshots_with_same_version(
932952 snapshot_rows .extend (self ._fetchall (query ))
933953
934954 return [
935- Snapshot (
936- ** {
937- ** json .loads (snapshot ),
938- "updated_ts" : updated_ts ,
939- "unpaused_ts" : unpaused_ts ,
940- "unrestorable" : unrestorable ,
941- }
955+ SharedVersionSnapshot .from_snapshot_record (
956+ name = name ,
957+ identifier = identifier ,
958+ version = version ,
959+ updated_ts = updated_ts ,
960+ unpaused_ts = unpaused_ts ,
961+ unrestorable = unrestorable ,
962+ snapshot = snapshot ,
942963 )
943- for snapshot , updated_ts , unpaused_ts , unrestorable in snapshot_rows
964+ for snapshot , name , identifier , version , updated_ts , unpaused_ts , unrestorable in snapshot_rows
944965 ]
945966
946967 def _get_versions (self , lock_for_update : bool = False ) -> Versions :
@@ -1942,3 +1963,94 @@ def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot:
19421963 if snapshot is None :
19431964 raise KeyError (snapshot_id )
19441965 return snapshot
1966+
1967+
1968+ class SharedVersionSnapshot (PydanticModel ):
1969+ """A stripped down version of a snapshot that is used for fetching snapshots that share the same version
1970+ with a significantly reduced parsing overhead.
1971+ """
1972+
1973+ name : str
1974+ version : str
1975+ temp_version_ : t .Optional [str ] = Field (alias = "temp_version" )
1976+ identifier : str
1977+ fingerprint : SnapshotFingerprint
1978+ interval_unit : IntervalUnit
1979+ change_category : SnapshotChangeCategory
1980+ updated_ts : int
1981+ unpaused_ts : t .Optional [int ]
1982+ unrestorable : bool
1983+ disable_restatement : bool
1984+ effective_from : t .Optional [TimeLike ]
1985+ raw_snapshot : t .Dict [str , t .Any ]
1986+
1987+ @property
1988+ def snapshot_id (self ) -> SnapshotId :
1989+ return SnapshotId (name = self .name , identifier = self .identifier )
1990+
1991+ @property
1992+ def is_forward_only (self ) -> bool :
1993+ return self .change_category == SnapshotChangeCategory .FORWARD_ONLY
1994+
1995+ @property
1996+ def normalized_effective_from_ts (self ) -> t .Optional [int ]:
1997+ return (
1998+ to_timestamp (self .interval_unit .cron_floor (self .effective_from ))
1999+ if self .effective_from
2000+ else None
2001+ )
2002+
2003+ @property
2004+ def temp_version (self ) -> str :
2005+ return self .temp_version_ or self .fingerprint .to_version ()
2006+
2007+ @property
2008+ def full_snapshot (self ) -> Snapshot :
2009+ return Snapshot (
2010+ ** {
2011+ ** self .raw_snapshot ,
2012+ "updated_ts" : self .updated_ts ,
2013+ "unpaused_ts" : self .unpaused_ts ,
2014+ "unrestorable" : self .unrestorable ,
2015+ }
2016+ )
2017+
2018+ def set_unpaused_ts (self , unpaused_dt : t .Optional [TimeLike ]) -> None :
2019+ """Sets the timestamp for when this snapshot was unpaused.
2020+
2021+ Args:
2022+ unpaused_dt: The datetime object of when this snapshot was unpaused.
2023+ """
2024+ self .unpaused_ts = (
2025+ to_timestamp (self .interval_unit .cron_floor (unpaused_dt )) if unpaused_dt else None
2026+ )
2027+
2028+ @classmethod
2029+ def from_snapshot_record (
2030+ cls ,
2031+ * ,
2032+ name : str ,
2033+ identifier : str ,
2034+ version : str ,
2035+ updated_ts : int ,
2036+ unpaused_ts : t .Optional [int ],
2037+ unrestorable : bool ,
2038+ snapshot : str ,
2039+ ) -> SharedVersionSnapshot :
2040+ raw_snapshot = json .loads (snapshot )
2041+ raw_node = raw_snapshot ["node" ]
2042+ return SharedVersionSnapshot (
2043+ name = name ,
2044+ version = version ,
2045+ temp_version = raw_snapshot .get ("temp_version" ),
2046+ identifier = identifier ,
2047+ fingerprint = raw_snapshot ["fingerprint" ],
2048+ interval_unit = raw_node .get ("interval_unit" , IntervalUnit .from_cron (raw_node ["cron" ])),
2049+ change_category = raw_snapshot ["change_category" ],
2050+ updated_ts = updated_ts ,
2051+ unpaused_ts = unpaused_ts ,
2052+ unrestorable = unrestorable ,
2053+ disable_restatement = raw_node .get ("kind" , {}).get ("disable_restatement" , False ),
2054+ effective_from = raw_snapshot .get ("effective_from" ),
2055+ raw_snapshot = raw_snapshot ,
2056+ )
0 commit comments