1010
1111from sqlmesh .core .engine_adapter import EngineAdapter
1212from sqlmesh .core .state_sync .db .utils import (
13+ snapshot_name_filter ,
1314 snapshot_name_version_filter ,
1415 snapshot_id_filter ,
1516 fetchone ,
3233 SnapshotChangeCategory ,
3334)
3435from sqlmesh .utils .migration import index_text_type , blob_text_type
35- from sqlmesh .utils .date import now_timestamp , TimeLike , now , to_timestamp
36+ from sqlmesh .utils .date import now_timestamp , TimeLike , to_timestamp
3637from sqlmesh .utils .pydantic import PydanticModel
3738from sqlmesh .utils import unique
3839
3940if t .TYPE_CHECKING :
4041 import pandas as pd
4142
42- from sqlmesh .core .state_sync .db .interval import IntervalState
43-
4443
4544logger = logging .getLogger (__name__ )
4645
@@ -70,6 +69,7 @@ def __init__(
7069 "unpaused_ts" : exp .DataType .build ("bigint" ),
7170 "ttl_ms" : exp .DataType .build ("bigint" ),
7271 "unrestorable" : exp .DataType .build ("boolean" ),
72+ "forward_only" : exp .DataType .build ("boolean" ),
7373 }
7474
7575 self ._auto_restatement_columns_to_types = {
@@ -112,84 +112,52 @@ def unpause_snapshots(
112112 self ,
113113 snapshots : t .Collection [SnapshotInfoLike ],
114114 unpaused_dt : TimeLike ,
115- interval_state : IntervalState ,
116115 ) -> None :
117- """Unpauses given snapshots while pausing all other snapshots that share the same version.
118-
119- Args:
120- snapshots: The snapshots to unpause.
121- unpaused_dt: The timestamp to unpause the snapshots at.
122- interval_state: The interval state to use to remove intervals when needed.
123- """
124- current_ts = now ()
125-
126- target_snapshot_ids = {s .snapshot_id for s in snapshots }
127- same_version_snapshots = self ._get_snapshots_with_same_version (
128- snapshots , lock_for_update = True
116+ unrestorable_snapshots_by_forward_only : t .Dict [bool , t .List [SnapshotNameVersion ]] = (
117+ defaultdict (list )
129118 )
130- target_snapshots_by_version = {
131- (s .name , s .version ): s
132- for s in same_version_snapshots
133- if s .snapshot_id in target_snapshot_ids
134- }
135-
136- unpaused_snapshots : t .Dict [int , t .List [SnapshotId ]] = defaultdict (list )
137- paused_snapshots : t .List [SnapshotId ] = []
138- unrestorable_snapshots : t .List [SnapshotId ] = []
139-
140- for snapshot in same_version_snapshots :
141- is_target_snapshot = snapshot .snapshot_id in target_snapshot_ids
142- if is_target_snapshot and not snapshot .unpaused_ts :
143- logger .info ("Unpausing snapshot %s" , snapshot .snapshot_id )
144- snapshot .set_unpaused_ts (unpaused_dt )
145- assert snapshot .unpaused_ts is not None
146- unpaused_snapshots [snapshot .unpaused_ts ].append (snapshot .snapshot_id )
147- elif not is_target_snapshot :
148- target_snapshot = target_snapshots_by_version [(snapshot .name , snapshot .version )]
149- if (
150- target_snapshot .normalized_effective_from_ts
151- and not target_snapshot .disable_restatement
152- ):
153- # Making sure that there are no overlapping intervals.
154- effective_from_ts = target_snapshot .normalized_effective_from_ts
155- logger .info (
156- "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s" ,
157- target_snapshot .effective_from ,
158- snapshot .snapshot_id ,
159- target_snapshot .snapshot_id ,
160- )
161- full_snapshot = snapshot .full_snapshot
162- interval_state .remove_intervals (
163- [
164- (
165- full_snapshot ,
166- full_snapshot .get_removal_interval (effective_from_ts , current_ts ),
167- )
168- ]
169- )
170119
171- if snapshot .unpaused_ts :
172- logger .info ("Pausing snapshot %s" , snapshot .snapshot_id )
173- snapshot .set_unpaused_ts (None )
174- paused_snapshots .append (snapshot .snapshot_id )
120+ for snapshot in snapshots :
121+ # We need to mark all other snapshots that have forward-only opposite to the target snapshot as unrestorable
122+ unrestorable_snapshots_by_forward_only [not snapshot .is_forward_only ].append (
123+ snapshot .name_version
124+ )
175125
176- if not snapshot .unrestorable and (
177- (target_snapshot .is_forward_only and not snapshot .is_forward_only )
178- or (snapshot .is_forward_only and not target_snapshot .is_forward_only )
179- ):
180- logger .info ("Marking snapshot %s as unrestorable" , snapshot .snapshot_id )
181- snapshot .unrestorable = True
182- unrestorable_snapshots .append (snapshot .snapshot_id )
126+ updated_ts = now_timestamp ()
127+ unpaused_ts = to_timestamp (unpaused_dt )
183128
184- if unpaused_snapshots :
185- for unpaused_ts , snapshot_ids in unpaused_snapshots .items ():
186- self ._update_snapshots (snapshot_ids , unpaused_ts = unpaused_ts )
129+ # Pause all snapshots with target names first
130+ for where in snapshot_name_filter (
131+ [s .name for s in snapshots ],
132+ batch_size = self .SNAPSHOT_BATCH_SIZE ,
133+ ):
134+ self .engine_adapter .update_table (
135+ self .snapshots_table ,
136+ {"unpaused_ts" : None , "updated_ts" : updated_ts },
137+ where = where ,
138+ )
187139
188- if paused_snapshots :
189- self ._update_snapshots (paused_snapshots , unpaused_ts = None )
140+ # Now unpause the target snapshots
141+ self ._update_snapshots (
142+ [s .snapshot_id for s in snapshots ],
143+ unpaused_ts = unpaused_ts ,
144+ updated_ts = updated_ts ,
145+ )
190146
191- if unrestorable_snapshots :
192- self ._update_snapshots (unrestorable_snapshots , unrestorable = True )
147+ # Mark unrestorable snapshots
148+ for forward_only , snapshot_name_versions in unrestorable_snapshots_by_forward_only .items ():
149+ forward_only_exp = exp .column ("forward_only" ).is_ (exp .convert (forward_only ))
150+ for where in snapshot_name_version_filter (
151+ self .engine_adapter ,
152+ snapshot_name_versions ,
153+ batch_size = self .SNAPSHOT_BATCH_SIZE ,
154+ alias = None ,
155+ ):
156+ self .engine_adapter .update_table (
157+ self .snapshots_table ,
158+ {"unrestorable" : True , "updated_ts" : updated_ts },
159+ where = forward_only_exp .and_ (where ),
160+ )
193161
194162 def get_expired_snapshots (
195163 self ,
@@ -414,7 +382,8 @@ def _update_snapshots(
414382 ** kwargs : t .Any ,
415383 ) -> None :
416384 properties = kwargs
417- properties ["updated_ts" ] = now_timestamp ()
385+ if "updated_ts" not in properties :
386+ properties ["updated_ts" ] = now_timestamp ()
418387
419388 for where in snapshot_id_filter (
420389 self .engine_adapter , snapshots , batch_size = self .SNAPSHOT_BATCH_SIZE
@@ -466,13 +435,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
466435 updated_ts ,
467436 unpaused_ts ,
468437 unrestorable ,
438+ forward_only ,
469439 next_auto_restatement_ts ,
470440 ) in fetchall (self .engine_adapter , query ):
471441 snapshot = parse_snapshot (
472442 serialized_snapshot = serialized_snapshot ,
473443 updated_ts = updated_ts ,
474444 unpaused_ts = unpaused_ts ,
475445 unrestorable = unrestorable ,
446+ forward_only = forward_only ,
476447 next_auto_restatement_ts = next_auto_restatement_ts ,
477448 )
478449 snapshot_id = snapshot .snapshot_id
@@ -502,6 +473,7 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
502473 "updated_ts" ,
503474 "unpaused_ts" ,
504475 "unrestorable" ,
476+ "forward_only" ,
505477 "next_auto_restatement_ts" ,
506478 )
507479 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
@@ -528,13 +500,15 @@ def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]:
528500 updated_ts ,
529501 unpaused_ts ,
530502 unrestorable ,
503+ forward_only ,
531504 next_auto_restatement_ts ,
532505 ) in fetchall (self .engine_adapter , query ):
533506 snapshot_id = SnapshotId (name = name , identifier = identifier )
534507 snapshot = snapshots [snapshot_id ]
535508 snapshot .updated_ts = updated_ts
536509 snapshot .unpaused_ts = unpaused_ts
537510 snapshot .unrestorable = unrestorable
511+ snapshot .forward_only = forward_only
538512 snapshot .next_auto_restatement_ts = next_auto_restatement_ts
539513 cached_snapshots_in_state .add (snapshot_id )
540514
@@ -568,6 +542,7 @@ def _get_snapshots_expressions(
568542 "snapshots.updated_ts" ,
569543 "snapshots.unpaused_ts" ,
570544 "snapshots.unrestorable" ,
545+ "snapshots.forward_only" ,
571546 "auto_restatements.next_auto_restatement_ts" ,
572547 )
573548 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
@@ -623,6 +598,7 @@ def _get_snapshots_with_same_version(
623598 "updated_ts" ,
624599 "unpaused_ts" ,
625600 "unrestorable" ,
601+ "forward_only" ,
626602 )
627603 .from_ (exp .to_table (self .snapshots_table ).as_ ("snapshots" ))
628604 .where (where )
@@ -640,9 +616,10 @@ def _get_snapshots_with_same_version(
640616 updated_ts = updated_ts ,
641617 unpaused_ts = unpaused_ts ,
642618 unrestorable = unrestorable ,
619+ forward_only = forward_only ,
643620 snapshot = snapshot ,
644621 )
645- for snapshot , name , identifier , version , updated_ts , unpaused_ts , unrestorable in snapshot_rows
622+ for snapshot , name , identifier , version , updated_ts , unpaused_ts , unrestorable , forward_only in snapshot_rows
646623 ]
647624
648625
@@ -651,6 +628,7 @@ def parse_snapshot(
651628 updated_ts : int ,
652629 unpaused_ts : t .Optional [int ],
653630 unrestorable : bool ,
631+ forward_only : bool ,
654632 next_auto_restatement_ts : t .Optional [int ],
655633) -> Snapshot :
656634 return Snapshot (
@@ -659,6 +637,7 @@ def parse_snapshot(
659637 "updated_ts" : updated_ts ,
660638 "unpaused_ts" : unpaused_ts ,
661639 "unrestorable" : unrestorable ,
640+ "forward_only" : forward_only ,
662641 "next_auto_restatement_ts" : next_auto_restatement_ts ,
663642 }
664643 )
@@ -673,6 +652,7 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
673652 "updated_ts" ,
674653 "unpaused_ts" ,
675654 "unrestorable" ,
655+ "forward_only" ,
676656 "next_auto_restatement_ts" ,
677657 }
678658 )
@@ -693,6 +673,7 @@ def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame:
693673 "unpaused_ts" : snapshot .unpaused_ts ,
694674 "ttl_ms" : snapshot .ttl_ms ,
695675 "unrestorable" : snapshot .unrestorable ,
676+ "forward_only" : snapshot .forward_only ,
696677 }
697678 for snapshot in snapshots
698679 ]
@@ -762,19 +743,10 @@ def full_snapshot(self) -> Snapshot:
762743 "updated_ts" : self .updated_ts ,
763744 "unpaused_ts" : self .unpaused_ts ,
764745 "unrestorable" : self .unrestorable ,
746+ "forward_only" : self .forward_only ,
765747 }
766748 )
767749
768- def set_unpaused_ts (self , unpaused_dt : t .Optional [TimeLike ]) -> None :
769- """Sets the timestamp for when this snapshot was unpaused.
770-
771- Args:
772- unpaused_dt: The datetime object of when this snapshot was unpaused.
773- """
774- self .unpaused_ts = (
775- to_timestamp (self .interval_unit .cron_floor (unpaused_dt )) if unpaused_dt else None
776- )
777-
778750 @classmethod
779751 def from_snapshot_record (
780752 cls ,
@@ -785,6 +757,7 @@ def from_snapshot_record(
785757 updated_ts : int ,
786758 unpaused_ts : t .Optional [int ],
787759 unrestorable : bool ,
760+ forward_only : bool ,
788761 snapshot : str ,
789762 ) -> SharedVersionSnapshot :
790763 raw_snapshot = json .loads (snapshot )
@@ -803,5 +776,5 @@ def from_snapshot_record(
803776 disable_restatement = raw_node .get ("kind" , {}).get ("disable_restatement" , False ),
804777 effective_from = raw_snapshot .get ("effective_from" ),
805778 raw_snapshot = raw_snapshot ,
806- forward_only = raw_snapshot . get ( " forward_only" , False ) ,
779+ forward_only = forward_only ,
807780 )
0 commit comments