5959from sqlmesh .core .state_sync .base import MIGRATIONS , SCHEMA_VERSION , StateSync , Versions
6060from sqlmesh .core .state_sync .common import CommonStateSyncMixin , transactional
6161from sqlmesh .utils import major_minor , random_id , unique
62+ from sqlmesh .utils .dag import DAG
6263from sqlmesh .utils .date import TimeLike , now_timestamp , time_like_to_str
6364from sqlmesh .utils .errors import SQLMeshError
6465from sqlmesh .utils .pydantic import parse_obj_as
@@ -899,7 +900,12 @@ def _restore_table(
899900 )
900901
901902 @transactional ()
902- def migrate (self , default_catalog : t .Optional [str ], skip_backup : bool = False ) -> None :
903+ def migrate (
904+ self ,
905+ default_catalog : t .Optional [str ],
906+ skip_backup : bool = False ,
907+ promoted_snapshots_only : bool = True ,
908+ ) -> None :
903909 """Migrate the state sync to the latest SQLMesh / SQLGlot version."""
904910 versions = self .get_versions (validate = False )
905911 migrations = MIGRATIONS [versions .schema_version :]
@@ -919,7 +925,7 @@ def migrate(self, default_catalog: t.Optional[str], skip_backup: bool = False) -
919925 logger .info (f"Applying migration { migration } " )
920926 migration .migrate (self , default_catalog = default_catalog )
921927
922- self ._migrate_rows ()
928+ self ._migrate_rows (promoted_snapshots_only )
923929 self ._update_versions ()
924930 except Exception as e :
925931 if skip_backup :
@@ -974,117 +980,55 @@ def _backup_state(self) -> None:
974980 backup_name , exp .select ("*" ).from_ (table ), exists = False
975981 )
976982
def _migrate_rows(self, promoted_snapshots_only: bool) -> None:
    """Migrate snapshot, seed, and environment rows to the current SQLMesh version.

    Environments are fetched once up front and passed through to
    ``_migrate_environment_rows`` so they are not re-fetched after the
    snapshot migration has run.

    Args:
        promoted_snapshots_only: If True, restrict the migration to snapshots that
            are referenced by at least one environment; if False, migrate all
            snapshot rows (``None`` filter).
    """
    logger.info("Fetching environments")
    environments = self.get_environments()
    # Only migrate snapshots that are part of at least one environment.
    snapshots_to_migrate = (
        {s.snapshot_id for e in environments for s in e.snapshots}
        if promoted_snapshots_only
        else None
    )
    snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate)
    if not snapshot_mapping:
        logger.info("No changes to snapshots detected")
        return
    # Seed and environment rows are only rewritten when at least one snapshot
    # was remapped to a new identity.
    self._migrate_seed_rows(snapshot_mapping)
    self._migrate_environment_rows(environments, snapshot_mapping)
985- def _migrate_snapshot_rows (self ) -> t .Dict [SnapshotId , SnapshotTableInfo ]:
999+ def _migrate_snapshot_rows (
1000+ self , snapshots : t .Optional [t .Set [SnapshotId ]]
1001+ ) -> t .Dict [SnapshotId , SnapshotTableInfo ]:
9861002 logger .info ("Migrating snapshot rows..." )
9871003 raw_snapshots = {
988- SnapshotId (name = name , identifier = identifier ): raw_snapshot
1004+ SnapshotId (name = name , identifier = identifier ): json .loads (raw_snapshot )
1005+ for where in (self ._snapshot_id_filter (snapshots ) if snapshots is not None else [None ])
9891006 for name , identifier , raw_snapshot in self ._fetchall (
990- exp .select ("name" , "identifier" , "snapshot" ).from_ (self .snapshots_table ).lock ()
1007+ exp .select ("name" , "identifier" , "snapshot" )
1008+ .from_ (self .snapshots_table )
1009+ .where (where )
1010+ .lock ()
9911011 )
9921012 }
9931013 if not raw_snapshots :
9941014 return {}
9951015
996- self .console .start_migration_progress (len (raw_snapshots ))
997-
998- all_snapshot_mapping : t .Dict [SnapshotId , SnapshotTableInfo ] = {}
999-
1000- for snapshot_id_batch in self ._snapshot_batches (
1001- sorted (raw_snapshots ), batch_size = self .SNAPSHOT_MIGRATION_BATCH_SIZE
1002- ):
1003- parsed_snapshots = LazilyParsedSnapshots (raw_snapshots )
1004- snapshot_id_mapping : t .Dict [SnapshotId , SnapshotId ] = {}
1005- new_snapshots : t .Dict [SnapshotId , Snapshot ] = {}
1006-
1007- for snapshot_id in snapshot_id_batch :
1008- snapshot = parsed_snapshots [snapshot_id ]
1009-
1010- seen = set ()
1011- queue = {snapshot .snapshot_id }
1012- node = snapshot .node
1013- nodes : t .Dict [str , Node ] = {}
1014- audits : t .Dict [str , ModelAudit ] = {}
1015-
1016- while queue :
1017- snapshot_id = queue .pop ()
1018-
1019- if snapshot_id in seen :
1020- continue
1021-
1022- seen .add (snapshot_id )
1023-
1024- s = parsed_snapshots .get (snapshot_id )
1025-
1026- if not s :
1027- continue
1028-
1029- queue .update (s .parents )
1030- nodes [s .name ] = s .node
1031- for audit in s .audits :
1032- audits [audit .name ] = audit
1033-
1034- new_snapshot = deepcopy (snapshot )
1035-
1036- fingerprint_cache : t .Dict [str , SnapshotFingerprint ] = {}
1037-
1038- try :
1039- new_snapshot .fingerprint = fingerprint_from_node (
1040- node ,
1041- nodes = nodes ,
1042- audits = audits ,
1043- )
1044- new_snapshot .parents = tuple (
1045- SnapshotId (
1046- name = parent_node .fqn ,
1047- identifier = fingerprint_from_node (
1048- parent_node ,
1049- nodes = nodes ,
1050- audits = audits ,
1051- cache = fingerprint_cache ,
1052- ).to_identifier (),
1053- )
1054- for parent_node in _parents_from_node (node , nodes ).values ()
1055- )
1056- except Exception :
1057- logger .exception ("Could not compute fingerprint for %s" , snapshot .snapshot_id )
1058- continue
1016+ dag : DAG [SnapshotId ] = DAG ()
1017+ for snapshot_id , raw_snapshot in raw_snapshots .items ():
1018+ parent_ids = [SnapshotId .parse_obj (p_id ) for p_id in raw_snapshot .get ("parents" , [])]
1019+ dag .add (snapshot_id , [p_id for p_id in parent_ids if p_id in raw_snapshots ])
10591020
1060- new_snapshot .previous_versions = snapshot .all_versions
1061- new_snapshot .migrated = True
1062- if not new_snapshot .temp_version :
1063- new_snapshot .temp_version = snapshot .fingerprint .to_version ()
1021+ reversed_dag_raw = dag .reversed .graph
10641022
1065- self .console .update_migration_progress (1 )
1066-
1067- if new_snapshot .fingerprint == snapshot .fingerprint :
1068- logger .debug (f"{ new_snapshot .snapshot_id } is unchanged." )
1069- continue
1070-
1071- new_snapshot_id = new_snapshot .snapshot_id
1072-
1073- if new_snapshot_id in raw_snapshots :
1074- # Mapped to an existing snapshot.
1075- new_snapshots [new_snapshot_id ] = parsed_snapshots [new_snapshot_id ]
1076- elif (
1077- new_snapshot_id not in new_snapshots
1078- or new_snapshot .updated_ts > new_snapshots [new_snapshot_id ].updated_ts
1079- ):
1080- new_snapshots [new_snapshot_id ] = new_snapshot
1081-
1082- snapshot_id_mapping [snapshot .snapshot_id ] = new_snapshot_id
1083- logger .debug (f"{ snapshot .snapshot_id } mapped to { new_snapshot_id } ." )
1023+ self .console .start_migration_progress (len (raw_snapshots ))
10841024
1085- if not new_snapshots :
1086- continue
1025+ parsed_snapshots = LazilyParsedSnapshots (raw_snapshots )
1026+ all_snapshot_mapping : t .Dict [SnapshotId , SnapshotTableInfo ] = {}
1027+ snapshot_id_mapping : t .Dict [SnapshotId , SnapshotId ] = {}
1028+ new_snapshots : t .Dict [SnapshotId , Snapshot ] = {}
1029+ visited : t .Set [SnapshotId ] = set ()
10871030
1031+ def _push_new_snapshots () -> None :
10881032 all_snapshot_mapping .update (
10891033 {
10901034 from_id : new_snapshots [to_id ].table_info
@@ -1097,10 +1041,102 @@ def _migrate_snapshot_rows(self) -> t.Dict[SnapshotId, SnapshotTableInfo]:
10971041 s for s in new_snapshots .values () if s .snapshot_id not in existing_new_snapshots
10981042 ]
10991043 if new_snapshots_to_push :
1044+ logger .info ("Pushing %s migrated snapshots" , len (new_snapshots_to_push ))
11001045 self ._push_snapshots (new_snapshots_to_push )
1046+ new_snapshots .clear ()
1047+ snapshot_id_mapping .clear ()
1048+
1049+ def _visit (
1050+ snapshot_id : SnapshotId , fingerprint_cache : t .Dict [str , SnapshotFingerprint ]
1051+ ) -> None :
1052+ if snapshot_id in visited or snapshot_id not in raw_snapshots :
1053+ return
1054+ visited .add (snapshot_id )
1055+
1056+ snapshot = parsed_snapshots [snapshot_id ]
1057+ node = snapshot .node
1058+
1059+ node_seen = set ()
1060+ node_queue = {snapshot_id }
1061+ nodes : t .Dict [str , Node ] = {}
1062+ audits : t .Dict [str , ModelAudit ] = {}
1063+ while node_queue :
1064+ next_snapshot_id = node_queue .pop ()
1065+ next_snapshot = parsed_snapshots .get (next_snapshot_id )
1066+
1067+ if next_snapshot_id in node_seen or not next_snapshot :
1068+ continue
1069+
1070+ node_seen .add (next_snapshot_id )
1071+ node_queue .update (next_snapshot .parents )
1072+
1073+ nodes [next_snapshot .name ] = next_snapshot .node
1074+ audits .update ({a .name : a for a in next_snapshot .audits })
1075+
1076+ new_snapshot = deepcopy (snapshot )
1077+ try :
1078+ new_snapshot .fingerprint = fingerprint_from_node (
1079+ node ,
1080+ nodes = nodes ,
1081+ audits = audits ,
1082+ cache = fingerprint_cache ,
1083+ )
1084+ new_snapshot .parents = tuple (
1085+ SnapshotId (
1086+ name = parent_node .fqn ,
1087+ identifier = fingerprint_from_node (
1088+ parent_node ,
1089+ nodes = nodes ,
1090+ audits = audits ,
1091+ cache = fingerprint_cache ,
1092+ ).to_identifier (),
1093+ )
1094+ for parent_node in _parents_from_node (node , nodes ).values ()
1095+ )
1096+ except Exception :
1097+ logger .exception ("Could not compute fingerprint for %s" , snapshot .snapshot_id )
1098+ return
1099+
1100+ new_snapshot .previous_versions = snapshot .all_versions
1101+ new_snapshot .migrated = True
1102+ if not new_snapshot .temp_version :
1103+ new_snapshot .temp_version = snapshot .fingerprint .to_version ()
1104+
1105+ self .console .update_migration_progress (1 )
1106+
1107+ # Visit children and evict them from the parsed_snapshots cache after.
1108+ for child in reversed_dag_raw .get (snapshot_id , []):
1109+ # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name.
1110+ _visit (child , fingerprint_cache .copy ())
1111+ parsed_snapshots .evict (child )
1112+
1113+ if new_snapshot .fingerprint == snapshot .fingerprint :
1114+ logger .debug (f"{ new_snapshot .snapshot_id } is unchanged." )
1115+ return
1116+
1117+ new_snapshot_id = new_snapshot .snapshot_id
1118+
1119+ if new_snapshot_id in raw_snapshots :
1120+ # Mapped to an existing snapshot.
1121+ new_snapshots [new_snapshot_id ] = parsed_snapshots [new_snapshot_id ]
1122+ logger .debug ("Migrated snapshot %s already exists" , new_snapshot_id )
1123+ elif (
1124+ new_snapshot_id not in new_snapshots
1125+ or new_snapshot .updated_ts > new_snapshots [new_snapshot_id ].updated_ts
1126+ ):
1127+ new_snapshots [new_snapshot_id ] = new_snapshot
11011128
1102- # Force cleanup to free memory.
1103- del parsed_snapshots
1129+ snapshot_id_mapping [snapshot .snapshot_id ] = new_snapshot_id
1130+ logger .debug ("%s mapped to %s" , snapshot .snapshot_id , new_snapshot_id )
1131+
1132+ if len (new_snapshots ) >= self .SNAPSHOT_MIGRATION_BATCH_SIZE :
1133+ _push_new_snapshots ()
1134+
1135+ for root_snapshot_id in dag .roots :
1136+ _visit (root_snapshot_id , {})
1137+
1138+ if new_snapshots :
1139+ _push_new_snapshots ()
11041140
11051141 return all_snapshot_mapping
11061142
@@ -1146,10 +1182,11 @@ def _migrate_seed_rows(self, snapshot_mapping: t.Dict[SnapshotId, SnapshotTableI
11461182 )
11471183
11481184 def _migrate_environment_rows (
1149- self , snapshot_mapping : t .Dict [SnapshotId , SnapshotTableInfo ]
1185+ self ,
1186+ environments : t .List [Environment ],
1187+ snapshot_mapping : t .Dict [SnapshotId , SnapshotTableInfo ],
11501188 ) -> None :
11511189 logger .info ("Migrating environment rows..." )
1152- environments = self .get_environments ()
11531190
11541191 updated_prod_environment : t .Optional [Environment ] = None
11551192 updated_environments = []
@@ -1355,19 +1392,22 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
13551392
13561393
13571394class LazilyParsedSnapshots :
def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]):
    """Initialize the lazy snapshot parser.

    Args:
        raw_snapshots: Mapping of snapshot IDs to pre-decoded JSON payloads
            (``dict`` objects, not raw JSON strings) that will be parsed into
            ``Snapshot`` instances on demand.
    """
    self._raw_snapshots = raw_snapshots
    # Cache of parsed results; a stored None marks an ID already looked up
    # and known to be missing, so we don't re-check the raw mapping.
    self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {}
def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]:
    """Return the parsed snapshot for ``snapshot_id``, parsing it on first access.

    Parses the pre-decoded payload with ``Snapshot.parse_obj`` and memoizes the
    result (including a ``None`` for missing IDs) so repeated lookups never
    re-parse.

    Args:
        snapshot_id: The ID of the snapshot to look up.

    Returns:
        The parsed ``Snapshot``, or ``None`` if the ID has no raw payload.
    """
    if snapshot_id not in self._parsed_snapshots:
        raw_snapshot = self._raw_snapshots.get(snapshot_id)
        if raw_snapshot:
            self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot)
        else:
            # Negative-cache the miss so future calls short-circuit.
            self._parsed_snapshots[snapshot_id] = None
    return self._parsed_snapshots[snapshot_id]
def evict(self, snapshot_id: SnapshotId) -> None:
    """Drop the parsed snapshot for ``snapshot_id`` from the cache to free memory.

    The raw payload is kept, so a later ``get`` re-parses it. Evicting an ID
    that was never parsed is a no-op.

    Args:
        snapshot_id: The ID whose parsed ``Snapshot`` should be released.
    """
    self._parsed_snapshots.pop(snapshot_id, None)
13711411 def __getitem__ (self , snapshot_id : SnapshotId ) -> Snapshot :
13721412 snapshot = self .get (snapshot_id )
13731413 if snapshot is None :
0 commit comments