Skip to content

Commit 446e150

Browse files
authored
Fix: Improve performance of migration (#2263)
1 parent 16ff411 commit 446e150

File tree

7 files changed

+178
-106
lines changed

7 files changed

+178
-106
lines changed

sqlmesh/core/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
load_config_from_yaml,
2222
load_configs,
2323
)
24+
from sqlmesh.core.config.migration import MigrationConfig
2425
from sqlmesh.core.config.model import ModelDefaultsConfig
2526
from sqlmesh.core.config.plan import PlanConfig
2627
from sqlmesh.core.config.root import Config

sqlmesh/core/config/migration.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from __future__ import annotations
2+
3+
from sqlmesh.core.config.base import BaseConfig
4+
5+
6+
class MigrationConfig(BaseConfig):
7+
"""Configuration for the SQLMesh state migration.
8+
9+
Args:
10+
promoted_snapshots_only: If True, only snapshots that are part of at least one environment will be migrated.
11+
Otherwise, all snapshots will be migrated.
12+
"""
13+
14+
promoted_snapshots_only: bool = True

sqlmesh/core/config/root.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sqlmesh.core.config.feature_flag import FeatureFlag
2424
from sqlmesh.core.config.format import FormatConfig
2525
from sqlmesh.core.config.gateway import GatewayConfig
26+
from sqlmesh.core.config.migration import MigrationConfig
2627
from sqlmesh.core.config.model import ModelDefaultsConfig
2728
from sqlmesh.core.config.plan import PlanConfig
2829
from sqlmesh.core.config.run import RunConfig
@@ -69,6 +70,8 @@ class Config(BaseConfig):
6970
format: The formatting options for SQL code.
7071
ui: The UI configuration for SQLMesh.
7172
feature_flags: Feature flags to enable/disable certain features.
73+
plan: The plan configuration.
74+
migration: The migration configuration.
7275
"""
7376

7477
gateways: t.Dict[str, GatewayConfig] = {"": GatewayConfig()}
@@ -104,6 +107,7 @@ class Config(BaseConfig):
104107
ui: UIConfig = UIConfig()
105108
feature_flags: FeatureFlag = FeatureFlag()
106109
plan: PlanConfig = PlanConfig()
110+
migration: MigrationConfig = MigrationConfig()
107111

108112
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
109113
"gateways": UpdateStrategy.KEY_UPDATE,

sqlmesh/core/context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,10 @@ def migrate(self) -> None:
14441444
"""
14451445
self.notification_target_manager.notify(NotificationEvent.MIGRATION_START)
14461446
try:
1447-
self._new_state_sync().migrate(default_catalog=self.default_catalog)
1447+
self._new_state_sync().migrate(
1448+
default_catalog=self.default_catalog,
1449+
promoted_snapshots_only=self.config.migration.promoted_snapshots_only,
1450+
)
14481451
except Exception as e:
14491452
self.notification_target_manager.notify(
14501453
NotificationEvent.MIGRATION_FAILURE, traceback.format_exc()

sqlmesh/core/state_sync/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,12 @@ def compact_intervals(self) -> None:
417417
"""
418418

419419
@abc.abstractmethod
420-
def migrate(self, default_catalog: t.Optional[str], skip_backup: bool = False) -> None:
420+
def migrate(
421+
self,
422+
default_catalog: t.Optional[str],
423+
skip_backup: bool = False,
424+
promoted_snapshots_only: bool = True,
425+
) -> None:
421426
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
422427

423428
@abc.abstractmethod

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 143 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from sqlmesh.core.state_sync.base import MIGRATIONS, SCHEMA_VERSION, StateSync, Versions
6060
from sqlmesh.core.state_sync.common import CommonStateSyncMixin, transactional
6161
from sqlmesh.utils import major_minor, random_id, unique
62+
from sqlmesh.utils.dag import DAG
6263
from sqlmesh.utils.date import TimeLike, now_timestamp, time_like_to_str
6364
from sqlmesh.utils.errors import SQLMeshError
6465
from sqlmesh.utils.pydantic import parse_obj_as
@@ -899,7 +900,12 @@ def _restore_table(
899900
)
900901

901902
@transactional()
902-
def migrate(self, default_catalog: t.Optional[str], skip_backup: bool = False) -> None:
903+
def migrate(
904+
self,
905+
default_catalog: t.Optional[str],
906+
skip_backup: bool = False,
907+
promoted_snapshots_only: bool = True,
908+
) -> None:
903909
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
904910
versions = self.get_versions(validate=False)
905911
migrations = MIGRATIONS[versions.schema_version :]
@@ -919,7 +925,7 @@ def migrate(self, default_catalog: t.Optional[str], skip_backup: bool = False) -
919925
logger.info(f"Applying migration {migration}")
920926
migration.migrate(self, default_catalog=default_catalog)
921927

922-
self._migrate_rows()
928+
self._migrate_rows(promoted_snapshots_only)
923929
self._update_versions()
924930
except Exception as e:
925931
if skip_backup:
@@ -974,117 +980,55 @@ def _backup_state(self) -> None:
974980
backup_name, exp.select("*").from_(table), exists=False
975981
)
976982

977-
def _migrate_rows(self) -> None:
978-
snapshot_mapping = self._migrate_snapshot_rows()
983+
def _migrate_rows(self, promoted_snapshots_only: bool) -> None:
984+
logger.info("Fetching environments")
985+
environments = self.get_environments()
986+
# Only migrate snapshots that are part of at least one environment.
987+
snapshots_to_migrate = (
988+
{s.snapshot_id for e in environments for s in e.snapshots}
989+
if promoted_snapshots_only
990+
else None
991+
)
992+
snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate)
979993
if not snapshot_mapping:
980-
logger.debug("No changes to snapshots detected")
994+
logger.info("No changes to snapshots detected")
981995
return
982996
self._migrate_seed_rows(snapshot_mapping)
983-
self._migrate_environment_rows(snapshot_mapping)
997+
self._migrate_environment_rows(environments, snapshot_mapping)
984998

985-
def _migrate_snapshot_rows(self) -> t.Dict[SnapshotId, SnapshotTableInfo]:
999+
def _migrate_snapshot_rows(
1000+
self, snapshots: t.Optional[t.Set[SnapshotId]]
1001+
) -> t.Dict[SnapshotId, SnapshotTableInfo]:
9861002
logger.info("Migrating snapshot rows...")
9871003
raw_snapshots = {
988-
SnapshotId(name=name, identifier=identifier): raw_snapshot
1004+
SnapshotId(name=name, identifier=identifier): json.loads(raw_snapshot)
1005+
for where in (self._snapshot_id_filter(snapshots) if snapshots is not None else [None])
9891006
for name, identifier, raw_snapshot in self._fetchall(
990-
exp.select("name", "identifier", "snapshot").from_(self.snapshots_table).lock()
1007+
exp.select("name", "identifier", "snapshot")
1008+
.from_(self.snapshots_table)
1009+
.where(where)
1010+
.lock()
9911011
)
9921012
}
9931013
if not raw_snapshots:
9941014
return {}
9951015

996-
self.console.start_migration_progress(len(raw_snapshots))
997-
998-
all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {}
999-
1000-
for snapshot_id_batch in self._snapshot_batches(
1001-
sorted(raw_snapshots), batch_size=self.SNAPSHOT_MIGRATION_BATCH_SIZE
1002-
):
1003-
parsed_snapshots = LazilyParsedSnapshots(raw_snapshots)
1004-
snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {}
1005-
new_snapshots: t.Dict[SnapshotId, Snapshot] = {}
1006-
1007-
for snapshot_id in snapshot_id_batch:
1008-
snapshot = parsed_snapshots[snapshot_id]
1009-
1010-
seen = set()
1011-
queue = {snapshot.snapshot_id}
1012-
node = snapshot.node
1013-
nodes: t.Dict[str, Node] = {}
1014-
audits: t.Dict[str, ModelAudit] = {}
1015-
1016-
while queue:
1017-
snapshot_id = queue.pop()
1018-
1019-
if snapshot_id in seen:
1020-
continue
1021-
1022-
seen.add(snapshot_id)
1023-
1024-
s = parsed_snapshots.get(snapshot_id)
1025-
1026-
if not s:
1027-
continue
1028-
1029-
queue.update(s.parents)
1030-
nodes[s.name] = s.node
1031-
for audit in s.audits:
1032-
audits[audit.name] = audit
1033-
1034-
new_snapshot = deepcopy(snapshot)
1035-
1036-
fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {}
1037-
1038-
try:
1039-
new_snapshot.fingerprint = fingerprint_from_node(
1040-
node,
1041-
nodes=nodes,
1042-
audits=audits,
1043-
)
1044-
new_snapshot.parents = tuple(
1045-
SnapshotId(
1046-
name=parent_node.fqn,
1047-
identifier=fingerprint_from_node(
1048-
parent_node,
1049-
nodes=nodes,
1050-
audits=audits,
1051-
cache=fingerprint_cache,
1052-
).to_identifier(),
1053-
)
1054-
for parent_node in _parents_from_node(node, nodes).values()
1055-
)
1056-
except Exception:
1057-
logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id)
1058-
continue
1016+
dag: DAG[SnapshotId] = DAG()
1017+
for snapshot_id, raw_snapshot in raw_snapshots.items():
1018+
parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])]
1019+
dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots])
10591020

1060-
new_snapshot.previous_versions = snapshot.all_versions
1061-
new_snapshot.migrated = True
1062-
if not new_snapshot.temp_version:
1063-
new_snapshot.temp_version = snapshot.fingerprint.to_version()
1021+
reversed_dag_raw = dag.reversed.graph
10641022

1065-
self.console.update_migration_progress(1)
1066-
1067-
if new_snapshot.fingerprint == snapshot.fingerprint:
1068-
logger.debug(f"{new_snapshot.snapshot_id} is unchanged.")
1069-
continue
1070-
1071-
new_snapshot_id = new_snapshot.snapshot_id
1072-
1073-
if new_snapshot_id in raw_snapshots:
1074-
# Mapped to an existing snapshot.
1075-
new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id]
1076-
elif (
1077-
new_snapshot_id not in new_snapshots
1078-
or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts
1079-
):
1080-
new_snapshots[new_snapshot_id] = new_snapshot
1081-
1082-
snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id
1083-
logger.debug(f"{snapshot.snapshot_id} mapped to {new_snapshot_id}.")
1023+
self.console.start_migration_progress(len(raw_snapshots))
10841024

1085-
if not new_snapshots:
1086-
continue
1025+
parsed_snapshots = LazilyParsedSnapshots(raw_snapshots)
1026+
all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {}
1027+
snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {}
1028+
new_snapshots: t.Dict[SnapshotId, Snapshot] = {}
1029+
visited: t.Set[SnapshotId] = set()
10871030

1031+
def _push_new_snapshots() -> None:
10881032
all_snapshot_mapping.update(
10891033
{
10901034
from_id: new_snapshots[to_id].table_info
@@ -1097,10 +1041,102 @@ def _migrate_snapshot_rows(self) -> t.Dict[SnapshotId, SnapshotTableInfo]:
10971041
s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots
10981042
]
10991043
if new_snapshots_to_push:
1044+
logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push))
11001045
self._push_snapshots(new_snapshots_to_push)
1046+
new_snapshots.clear()
1047+
snapshot_id_mapping.clear()
1048+
1049+
def _visit(
1050+
snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint]
1051+
) -> None:
1052+
if snapshot_id in visited or snapshot_id not in raw_snapshots:
1053+
return
1054+
visited.add(snapshot_id)
1055+
1056+
snapshot = parsed_snapshots[snapshot_id]
1057+
node = snapshot.node
1058+
1059+
node_seen = set()
1060+
node_queue = {snapshot_id}
1061+
nodes: t.Dict[str, Node] = {}
1062+
audits: t.Dict[str, ModelAudit] = {}
1063+
while node_queue:
1064+
next_snapshot_id = node_queue.pop()
1065+
next_snapshot = parsed_snapshots.get(next_snapshot_id)
1066+
1067+
if next_snapshot_id in node_seen or not next_snapshot:
1068+
continue
1069+
1070+
node_seen.add(next_snapshot_id)
1071+
node_queue.update(next_snapshot.parents)
1072+
1073+
nodes[next_snapshot.name] = next_snapshot.node
1074+
audits.update({a.name: a for a in next_snapshot.audits})
1075+
1076+
new_snapshot = deepcopy(snapshot)
1077+
try:
1078+
new_snapshot.fingerprint = fingerprint_from_node(
1079+
node,
1080+
nodes=nodes,
1081+
audits=audits,
1082+
cache=fingerprint_cache,
1083+
)
1084+
new_snapshot.parents = tuple(
1085+
SnapshotId(
1086+
name=parent_node.fqn,
1087+
identifier=fingerprint_from_node(
1088+
parent_node,
1089+
nodes=nodes,
1090+
audits=audits,
1091+
cache=fingerprint_cache,
1092+
).to_identifier(),
1093+
)
1094+
for parent_node in _parents_from_node(node, nodes).values()
1095+
)
1096+
except Exception:
1097+
logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id)
1098+
return
1099+
1100+
new_snapshot.previous_versions = snapshot.all_versions
1101+
new_snapshot.migrated = True
1102+
if not new_snapshot.temp_version:
1103+
new_snapshot.temp_version = snapshot.fingerprint.to_version()
1104+
1105+
self.console.update_migration_progress(1)
1106+
1107+
# Visit children and evict them from the parsed_snapshots cache after.
1108+
for child in reversed_dag_raw.get(snapshot_id, []):
1109+
# Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name.
1110+
_visit(child, fingerprint_cache.copy())
1111+
parsed_snapshots.evict(child)
1112+
1113+
if new_snapshot.fingerprint == snapshot.fingerprint:
1114+
logger.debug(f"{new_snapshot.snapshot_id} is unchanged.")
1115+
return
1116+
1117+
new_snapshot_id = new_snapshot.snapshot_id
1118+
1119+
if new_snapshot_id in raw_snapshots:
1120+
# Mapped to an existing snapshot.
1121+
new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id]
1122+
logger.debug("Migrated snapshot %s already exists", new_snapshot_id)
1123+
elif (
1124+
new_snapshot_id not in new_snapshots
1125+
or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts
1126+
):
1127+
new_snapshots[new_snapshot_id] = new_snapshot
11011128

1102-
# Force cleanup to free memory.
1103-
del parsed_snapshots
1129+
snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id
1130+
logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id)
1131+
1132+
if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE:
1133+
_push_new_snapshots()
1134+
1135+
for root_snapshot_id in dag.roots:
1136+
_visit(root_snapshot_id, {})
1137+
1138+
if new_snapshots:
1139+
_push_new_snapshots()
11041140

11051141
return all_snapshot_mapping
11061142

@@ -1146,10 +1182,11 @@ def _migrate_seed_rows(self, snapshot_mapping: t.Dict[SnapshotId, SnapshotTableI
11461182
)
11471183

11481184
def _migrate_environment_rows(
1149-
self, snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo]
1185+
self,
1186+
environments: t.List[Environment],
1187+
snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo],
11501188
) -> None:
11511189
logger.info("Migrating environment rows...")
1152-
environments = self.get_environments()
11531190

11541191
updated_prod_environment: t.Optional[Environment] = None
11551192
updated_environments = []
@@ -1355,19 +1392,22 @@ def _snapshot_to_json(snapshot: Snapshot) -> str:
13551392

13561393

13571394
class LazilyParsedSnapshots:
1358-
def __init__(self, raw_snapshots: t.Dict[SnapshotId, str]):
1395+
def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]):
13591396
self._raw_snapshots = raw_snapshots
13601397
self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {}
13611398

13621399
def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]:
13631400
if snapshot_id not in self._parsed_snapshots:
13641401
raw_snapshot = self._raw_snapshots.get(snapshot_id)
13651402
if raw_snapshot:
1366-
self._parsed_snapshots[snapshot_id] = Snapshot.parse_raw(raw_snapshot)
1403+
self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot)
13671404
else:
13681405
self._parsed_snapshots[snapshot_id] = None
13691406
return self._parsed_snapshots[snapshot_id]
13701407

1408+
def evict(self, snapshot_id: SnapshotId) -> None:
1409+
self._parsed_snapshots.pop(snapshot_id, None)
1410+
13711411
def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot:
13721412
snapshot = self.get(snapshot_id)
13731413
if snapshot is None:

0 commit comments

Comments (0)