1010
1111from sqlmesh .core .engine_adapter import EngineAdapter
1212from sqlmesh .core .state_sync .db .utils import (
13- snapshot_name_filter ,
1413 snapshot_name_version_filter ,
1514 snapshot_id_filter ,
1615 fetchone ,
3332 SnapshotChangeCategory ,
3433)
3534from sqlmesh .utils .migration import index_text_type , blob_text_type
36- from sqlmesh .utils .date import now_timestamp , TimeLike , to_timestamp
35+ from sqlmesh .utils .date import now_timestamp , TimeLike , now , to_timestamp
3736from sqlmesh .utils .pydantic import PydanticModel
3837from sqlmesh .utils import unique
3938
4039if t .TYPE_CHECKING :
4140 import pandas as pd
4241
42+ from sqlmesh .core .state_sync .db .interval import IntervalState
43+
4344
4445logger = logging .getLogger (__name__ )
4546
@@ -69,7 +70,6 @@ def __init__(
6970 "unpaused_ts" : exp .DataType .build ("bigint" ),
7071 "ttl_ms" : exp .DataType .build ("bigint" ),
7172 "unrestorable" : exp .DataType .build ("boolean" ),
72- "forward_only" : exp .DataType .build ("boolean" ),
7373 }
7474
7575 self ._auto_restatement_columns_to_types = {
@@ -112,48 +112,84 @@ def unpause_snapshots(
112112 self ,
113113 snapshots : t .Collection [SnapshotInfoLike ],
114114 unpaused_dt : TimeLike ,
115+ interval_state : IntervalState ,
115116 ) -> None :
116- unrestorable_snapshots_by_forward_only : t . Dict [ bool , t . List [ str ]] = defaultdict ( list )
117+ """Unpauses given snapshots while pausing all other snapshots that share the same version.
117118
118- for snapshot in snapshots :
119- # We need to mark all other snapshots that have opposite forward only status as unrestorable
120- unrestorable_snapshots_by_forward_only [not snapshot .is_forward_only ].append (
121- snapshot .name
122- )
119+ Args:
120+ snapshots: The snapshots to unpause.
121+ unpaused_dt: The timestamp to unpause the snapshots at.
122+ interval_state: The interval state to use to remove intervals when needed.
123+ """
124+ current_ts = now ()
123125
124- updated_ts = now_timestamp ()
125- unpaused_ts = to_timestamp (unpaused_dt )
126+ target_snapshot_ids = {s .snapshot_id for s in snapshots }
127+ same_version_snapshots = self ._get_snapshots_with_same_version (
128+ snapshots , lock_for_update = True
129+ )
130+ target_snapshots_by_version = {
131+ (s .name , s .version ): s
132+ for s in same_version_snapshots
133+ if s .snapshot_id in target_snapshot_ids
134+ }
126135
127- # Pause all snapshots with target names first
128- for where in snapshot_name_filter (
129- [s .name for s in snapshots ],
130- batch_size = self .SNAPSHOT_BATCH_SIZE ,
131- ):
132- self .engine_adapter .update_table (
133- self .snapshots_table ,
134- {"unpaused_ts" : None , "updated_ts" : updated_ts },
135- where = where ,
136- )
136+ unpaused_snapshots : t .Dict [int , t .List [SnapshotId ]] = defaultdict (list )
137+ paused_snapshots : t .List [SnapshotId ] = []
138+ unrestorable_snapshots : t .List [SnapshotId ] = []
139+
140+ for snapshot in same_version_snapshots :
141+ is_target_snapshot = snapshot .snapshot_id in target_snapshot_ids
142+ if is_target_snapshot and not snapshot .unpaused_ts :
143+ logger .info ("Unpausing snapshot %s" , snapshot .snapshot_id )
144+ snapshot .set_unpaused_ts (unpaused_dt )
145+ assert snapshot .unpaused_ts is not None
146+ unpaused_snapshots [snapshot .unpaused_ts ].append (snapshot .snapshot_id )
147+ elif not is_target_snapshot :
148+ target_snapshot = target_snapshots_by_version [(snapshot .name , snapshot .version )]
149+ if (
150+ target_snapshot .normalized_effective_from_ts
151+ and not target_snapshot .disable_restatement
152+ ):
153+ # Making sure that there are no overlapping intervals.
154+ effective_from_ts = target_snapshot .normalized_effective_from_ts
155+ logger .info (
156+ "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s" ,
157+ target_snapshot .effective_from ,
158+ snapshot .snapshot_id ,
159+ target_snapshot .snapshot_id ,
160+ )
161+ full_snapshot = snapshot .full_snapshot
162+ interval_state .remove_intervals (
163+ [
164+ (
165+ full_snapshot ,
166+ full_snapshot .get_removal_interval (effective_from_ts , current_ts ),
167+ )
168+ ]
169+ )
137170
138- # Now unpause the target snapshots
139- self ._update_snapshots (
140- [s .snapshot_id for s in snapshots ],
141- unpaused_ts = unpaused_ts ,
142- updated_ts = updated_ts ,
143- )
171+ if snapshot .unpaused_ts :
172+ logger .info ("Pausing snapshot %s" , snapshot .snapshot_id )
173+ snapshot .set_unpaused_ts (None )
174+ paused_snapshots .append (snapshot .snapshot_id )
144175
145- # Mark unrestorable snapshots
146- for forward_only , snapshot_names in unrestorable_snapshots_by_forward_only .items ():
147- forward_only_exp = exp .column ("forward_only" ).is_ (exp .convert (forward_only ))
148- for where in snapshot_name_filter (
149- snapshot_names ,
150- batch_size = self .SNAPSHOT_BATCH_SIZE ,
151- ):
152- self .engine_adapter .update_table (
153- self .snapshots_table ,
154- {"unrestorable" : True , "updated_ts" : updated_ts },
155- where = forward_only_exp .and_ (where ),
156- )
176+ if not snapshot .unrestorable and (
177+ (target_snapshot .is_forward_only and not snapshot .is_forward_only )
178+ or (snapshot .is_forward_only and not target_snapshot .is_forward_only )
179+ ):
180+ logger .info ("Marking snapshot %s as unrestorable" , snapshot .snapshot_id )
181+ snapshot .unrestorable = True
182+ unrestorable_snapshots .append (snapshot .snapshot_id )
183+
184+ if unpaused_snapshots :
185+ for unpaused_ts , snapshot_ids in unpaused_snapshots .items ():
186+ self ._update_snapshots (snapshot_ids , unpaused_ts = unpaused_ts )
187+
188+ if paused_snapshots :
189+ self ._update_snapshots (paused_snapshots , unpaused_ts = None )
190+
191+ if unrestorable_snapshots :
192+ self ._update_snapshots (unrestorable_snapshots , unrestorable = True )
157193
158194 def get_expired_snapshots (
159195 self ,
@@ -378,8 +414,7 @@ def _update_snapshots(
378414 ** kwargs : t .Any ,
379415 ) -> None :
380416 properties = kwargs
381- if "updated_ts" not in properties :
382- properties ["updated_ts" ] = now_timestamp ()
417+ properties ["updated_ts" ] = now_timestamp ()
383418
384419 for where in snapshot_id_filter (
385420 self .engine_adapter , snapshots , batch_size = self .SNAPSHOT_BATCH_SIZE
@@ -431,15 +466,13 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
431466 updated_ts ,
432467 unpaused_ts ,
433468 unrestorable ,
434- forward_only ,
435469 next_auto_restatement_ts ,
436470 ) in fetchall (self .engine_adapter , query ):
437471 snapshot = parse_snapshot (
438472 serialized_snapshot = serialized_snapshot ,
439473 updated_ts = updated_ts ,
440474 unpaused_ts = unpaused_ts ,
441475 unrestorable = unrestorable ,
442- forward_only = forward_only ,
443476 next_auto_restatement_ts = next_auto_restatement_ts ,
444477 )
445478 snapshot_id = snapshot .snapshot_id
@@ -469,7 +502,6 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
469502 "updated_ts" ,
470503 "unpaused_ts" ,
471504 "unrestorable" ,
472- "forward_only" ,
473505 "next_auto_restatement_ts" ,
474506 )
475507 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
@@ -496,15 +528,13 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
496528 updated_ts ,
497529 unpaused_ts ,
498530 unrestorable ,
499- forward_only ,
500531 next_auto_restatement_ts ,
501532 ) in fetchall (self .engine_adapter , query ):
502533 snapshot_id = SnapshotId (name = name , identifier = identifier )
503534 snapshot = snapshots [snapshot_id ]
504535 snapshot .updated_ts = updated_ts
505536 snapshot .unpaused_ts = unpaused_ts
506537 snapshot .unrestorable = unrestorable
507- snapshot .forward_only = forward_only
508538 snapshot .next_auto_restatement_ts = next_auto_restatement_ts
509539 cached_snapshots_in_state .add (snapshot_id )
510540
@@ -538,7 +568,6 @@ def _get_snapshots_expressions(
538568 "snapshots.updated_ts" ,
539569 "snapshots.unpaused_ts" ,
540570 "snapshots.unrestorable" ,
541- "snapshots.forward_only" ,
542571 "auto_restatements.next_auto_restatement_ts" ,
543572 )
544573 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
@@ -594,7 +623,6 @@ def _get_snapshots_with_same_version(
594623 "updated_ts" ,
595624 "unpaused_ts" ,
596625 "unrestorable" ,
597- "forward_only" ,
598626 )
599627 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
600628 .where (where )
@@ -612,10 +640,9 @@ def _get_snapshots_with_same_version(
612640 updated_ts = updated_ts ,
613641 unpaused_ts = unpaused_ts ,
614642 unrestorable = unrestorable ,
615- forward_only = forward_only ,
616643 snapshot = snapshot ,
617644 )
618- for snapshot , name , identifier , version , updated_ts , unpaused_ts , unrestorable , forward_only in snapshot_rows
645+ for snapshot , name , identifier , version , updated_ts , unpaused_ts , unrestorable in snapshot_rows
619646 ]
620647
621648
@@ -624,7 +651,6 @@ def parse_snapshot(
624651 updated_ts : int ,
625652 unpaused_ts : t .Optional [int ],
626653 unrestorable : bool ,
627- forward_only : bool ,
628654 next_auto_restatement_ts : t .Optional [int ],
629655) -> Snapshot :
630656 return Snapshot (
@@ -633,7 +659,6 @@ def parse_snapshot(
633659 "updated_ts" : updated_ts ,
634660 "unpaused_ts" : unpaused_ts ,
635661 "unrestorable" : unrestorable ,
636- "forward_only" : forward_only ,
637662 "next_auto_restatement_ts" : next_auto_restatement_ts ,
638663 }
639664 )
@@ -648,7 +673,6 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
648673 "updated_ts" ,
649674 "unpaused_ts" ,
650675 "unrestorable" ,
651- "forward_only" ,
652676 "next_auto_restatement_ts" ,
653677 }
654678 )
@@ -669,7 +693,6 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
669693 "unpaused_ts" : snapshot .unpaused_ts ,
670694 "ttl_ms" : snapshot .ttl_ms ,
671695 "unrestorable" : snapshot .unrestorable ,
672- "forward_only" : snapshot .forward_only ,
673696 }
674697 for snapshot in snapshots
675698 ]
@@ -739,10 +762,19 @@ def full_snapshot(self) -> Snapshot:
739762 "updated_ts" : self .updated_ts ,
740763 "unpaused_ts" : self .unpaused_ts ,
741764 "unrestorable" : self .unrestorable ,
742- "forward_only" : self .forward_only ,
743765 }
744766 )
745767
768+ def set_unpaused_ts (self , unpaused_dt : t .Optional [TimeLike ]) -> None :
769+ """Sets the timestamp for when this snapshot was unpaused.
770+
771+ Args:
772+ unpaused_dt: The datetime object of when this snapshot was unpaused.
773+ """
774+ self .unpaused_ts = (
775+ to_timestamp (self .interval_unit .cron_floor (unpaused_dt )) if unpaused_dt else None
776+ )
777+
746778 @classmethod
747779 def from_snapshot_record (
748780 cls ,
@@ -753,7 +785,6 @@ def from_snapshot_record(
753785 updated_ts : int ,
754786 unpaused_ts : t .Optional [int ],
755787 unrestorable : bool ,
756- forward_only : bool ,
757788 snapshot : str ,
758789 ) -> SharedVersionSnapshot :
759790 raw_snapshot = json .loads (snapshot )
@@ -772,5 +803,5 @@ def from_snapshot_record(
772803 disable_restatement = raw_node .get ("kind" , {}).get ("disable_restatement" , False ),
773804 effective_from = raw_snapshot .get ("effective_from" ),
774805 raw_snapshot = raw_snapshot ,
775- forward_only = forward_only ,
806+ forward_only = raw_snapshot . get ( " forward_only" , False ) ,
776807 )
0 commit comments