|
25 | 25 | from sqlmesh.core.snapshot.definition import ( |
26 | 26 | _parents_from_node, |
27 | 27 | ) |
28 | | -from sqlmesh.core.state_sync.base import ( |
29 | | - MIGRATIONS, |
30 | | - PRE_CHECKS, |
31 | | -) |
| 28 | +from sqlmesh.core.state_sync.base import MIGRATIONS |
32 | 29 | from sqlmesh.core.state_sync.base import StateSync |
33 | 30 | from sqlmesh.core.state_sync.db.environment import EnvironmentState |
34 | 31 | from sqlmesh.core.state_sync.db.interval import IntervalState |
@@ -102,7 +99,7 @@ def migrate( |
102 | 99 | promoted_snapshots_only: Whether to migrate only promoted snapshots. |
103 | 100 | pre_check_only: If True, only run pre-checks without performing migration. |
104 | 101 | """ |
105 | | - pre_check_warnings = self.run_pre_checks(state_sync) |
| 102 | + pre_check_warnings = self._run_pre_checks(state_sync) |
106 | 103 | should_migrate = self.console.log_pre_check_warnings(pre_check_warnings, pre_check_only) |
107 | 104 | if not should_migrate: |
108 | 105 | return |
@@ -168,30 +165,27 @@ def rollback(self) -> None: |
168 | 165 |
|
169 | 166 | logger.info("Migration rollback successful.") |
170 | 167 |
|
171 | | - def run_pre_checks(self, state_sync: StateSync) -> t.List[t.Tuple[str, t.List[str]]]: |
| 168 | + def _run_pre_checks(self, state_sync: StateSync) -> t.List[t.Tuple[str, t.List[str]]]: |
172 | 169 | """Run pre-checks for migrations between specified versions. |
173 | 170 |
|
174 | 171 | Args: |
175 | 172 | state_sync: The state sync instance. |
176 | 173 |
|
177 | 174 | Returns: |
178 | | - A list of pairs comprising the executed pre-checks and the corresponding warnings. |
| 175 | + A list of pairs comprising the migration name containing the executed pre-checks |
| 176 | + and the corresponding warnings. |
179 | 177 | """ |
180 | | - # Get the range of the migrations that would be applied |
181 | | - from_version = self.version_state.get_versions().schema_version |
182 | | - to_version = len(MIGRATIONS) |
| 178 | + versions = self.version_state.get_versions() |
| 179 | + migrations = MIGRATIONS[versions.schema_version :] |
183 | 180 |
|
184 | 181 | pre_check_warnings = [] |
185 | | - for i in range(from_version, to_version): |
186 | | - # Assumption: pre-check and migration names match |
187 | | - pre_check_name = MIGRATIONS[i].__name__.split(".")[-1] |
188 | | - pre_check_module = PRE_CHECKS.get(pre_check_name) |
189 | | - |
190 | | - if callable(pre_check := getattr(pre_check_module, "pre_check", None)): |
191 | | - logger.info(f"Running pre-check for {pre_check_name}") |
| 182 | + for migration in migrations: |
| 183 | + if callable(pre_check := getattr(migration, "pre_check", None)): |
| 184 | + migration_name = migration.__name__.split(".")[-1] |
| 185 | + logger.info(f"Running pre-check for {migration_name}") |
192 | 186 | warnings = pre_check(state_sync) |
193 | 187 | if warnings: |
194 | | - pre_check_warnings.append((pre_check_name, warnings)) |
| 188 | + pre_check_warnings.append((migration_name, warnings)) |
195 | 189 |
|
196 | 190 | return pre_check_warnings |
197 | 191 |
|
|
0 commit comments