diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 8982efc9f8..83089ed7bc 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -917,12 +917,17 @@ def ui(ctx: click.Context, host: str, port: int, mode: str) -> None: @cli.command("migrate") +@click.option( + "--pre-check", + is_flag=True, + help="Run pre-checks and display warnings without performing migration", +) @click.pass_context @error_handler @cli_analytics -def migrate(ctx: click.Context) -> None: +def migrate(ctx: click.Context, pre_check: bool) -> None: """Migrate SQLMesh to the current running version.""" - ctx.obj.migrate() + ctx.obj.migrate(pre_check_only=pre_check) @cli.command("rollback") diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index cf87fd7443..6ad18334e6 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -497,6 +497,13 @@ def update_env_migration_progress(self, num_tasks: int) -> None: def stop_env_migration_progress(self, success: bool = True) -> None: """Stop the environment migration progress.""" + @abc.abstractmethod + def log_pre_check_warnings(self, pre_check_warnings: t.List[str], pre_check_only: bool) -> bool: + """ + Log warnings emitted by pre-checks and ask user whether they'd like to + proceed with the migration (true) or not (false). 
+ """ + @abc.abstractmethod def plan( self, @@ -662,6 +669,9 @@ def update_env_migration_progress(self, num_tasks: int) -> None: def stop_env_migration_progress(self, success: bool = True) -> None: pass + def log_pre_check_warnings(self, pre_check_warnings: t.List[str], pre_check_only: bool) -> bool: + return True + def start_state_export( self, output_file: Path, @@ -1472,6 +1482,28 @@ def stop_env_migration_progress(self, success: bool = True) -> None: if success: self.log_success("Environments migrated successfully") + def log_pre_check_warnings(self, pre_check_warnings: t.List[str], pre_check_only: bool) -> bool: + if pre_check_warnings: + tree = Tree(f"[bold]Pre-migration warnings[/bold]") + for warning in pre_check_warnings: + tree.add(f"[yellow]{warning}[/yellow]") + + self._print(tree) + + if pre_check_only: + return False + + should_continue = self._confirm("\nDo you want to proceed with the migration?") + if not should_continue: + self.log_status_update("Migration cancelled.") + + return should_continue + if pre_check_only: + self.log_status_update("No pre-migration warnings detected.") + return False + + return True + def start_state_export( self, output_file: Path, diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index c0d9b21ff8..04520b35c0 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2319,24 +2319,33 @@ def check_intervals( return results @python_api_analytics - def migrate(self) -> None: + def migrate(self, pre_check_only: bool = False) -> None: """Migrates SQLMesh to the current running version. Please contact your SQLMesh administrator before doing this. + + Args: + pre_check_only: If True, only run pre-checks without performing the migration. 
""" - self.notification_target_manager.notify(NotificationEvent.MIGRATION_START) + if not pre_check_only: + self.notification_target_manager.notify(NotificationEvent.MIGRATION_START) + self._load_materializations() try: self._new_state_sync().migrate( default_catalog=self.default_catalog, promoted_snapshots_only=self.config.migration.promoted_snapshots_only, + pre_check_only=pre_check_only, ) except Exception as e: - self.notification_target_manager.notify( - NotificationEvent.MIGRATION_FAILURE, traceback.format_exc() - ) + if not pre_check_only: + self.notification_target_manager.notify( + NotificationEvent.MIGRATION_FAILURE, traceback.format_exc() + ) raise e - self.notification_target_manager.notify(NotificationEvent.MIGRATION_END) + + if not pre_check_only: + self.notification_target_manager.notify(NotificationEvent.MIGRATION_END) @python_api_analytics def rollback(self) -> None: diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 6c2097d760..fa9841686b 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -40,6 +40,7 @@ class Versions(PydanticModel): schema_version: int = 0 sqlglot_version: str = "0.0.0" sqlmesh_version: str = "0.0.0" + pre_check_version: int = 0 @property def minor_sqlglot_version(self) -> t.Tuple[int, int]: @@ -54,9 +55,9 @@ def minor_sqlmesh_version(self) -> t.Tuple[int, int]: def _package_version_validator(cls, v: t.Any) -> str: return "0.0.0" if v is None else str(v) - @field_validator("schema_version", mode="before") + @field_validator("schema_version", "pre_check_version", mode="before") @classmethod - def _schema_version_validator(cls, v: t.Any) -> int: + def _int_version_validator(cls, v: t.Any) -> int: return 0 if v is None else int(v) @@ -65,6 +66,13 @@ def _schema_version_validator(cls, v: t.Any) -> int: for migration in sorted(info.name for info in pkgutil.iter_modules(migrations.__path__)) ] SCHEMA_VERSION: int = len(MIGRATIONS) +PRE_CHECK_VERSION: int = ( + max( + 
[idx for idx, migration in enumerate(MIGRATIONS) if hasattr(migration, "pre_check")], + default=-1, + ) + + 1 +) class PromotionResult(PydanticModel): @@ -456,6 +464,7 @@ def migrate( default_catalog: t.Optional[str], skip_backup: bool = False, promoted_snapshots_only: bool = True, + pre_check_only: bool = False, ) -> None: """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" diff --git a/sqlmesh/core/state_sync/db/facade.py b/sqlmesh/core/state_sync/db/facade.py index 779add1cca..765545a334 100644 --- a/sqlmesh/core/state_sync/db/facade.py +++ b/sqlmesh/core/state_sync/db/facade.py @@ -447,6 +447,7 @@ def migrate( default_catalog: t.Optional[str], skip_backup: bool = False, promoted_snapshots_only: bool = True, + pre_check_only: bool = False, ) -> None: """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" self.migrator.migrate( @@ -454,6 +455,7 @@ def migrate( default_catalog, skip_backup=skip_backup, promoted_snapshots_only=promoted_snapshots_only, + pre_check_only=pre_check_only, ) @transactional() diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py index 405c0ea667..770ce57f66 100644 --- a/sqlmesh/core/state_sync/db/migrator.py +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -25,9 +25,7 @@ from sqlmesh.core.snapshot.definition import ( _parents_from_node, ) -from sqlmesh.core.state_sync.base import ( - MIGRATIONS, -) +from sqlmesh.core.state_sync.base import MIGRATIONS from sqlmesh.core.state_sync.base import StateSync from sqlmesh.core.state_sync.db.environment import EnvironmentState from sqlmesh.core.state_sync.db.interval import IntervalState @@ -90,8 +88,22 @@ def migrate( default_catalog: t.Optional[str], skip_backup: bool = False, promoted_snapshots_only: bool = True, + pre_check_only: bool = False, ) -> None: - """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" + """Migrate the state sync to the latest SQLMesh / SQLGlot version. 
+ + Args: + state_sync: The state sync instance. + default_catalog: The default catalog. + skip_backup: Whether to skip backing up state tables. + promoted_snapshots_only: Whether to migrate only promoted snapshots. + pre_check_only: If True, only run pre-checks without performing migration. + """ + pre_check_warnings = self._run_pre_checks(state_sync) + should_migrate = self.console.log_pre_check_warnings(pre_check_warnings, pre_check_only) + if not should_migrate: + return + versions = self.version_state.get_versions() migration_start_ts = time.perf_counter() @@ -153,6 +165,30 @@ def rollback(self) -> None: logger.info("Migration rollback successful.") + def _run_pre_checks(self, state_sync: StateSync) -> t.List[str]: + """Run the pre-checks of the migrations from the stored schema version onward. + + Args: + state_sync: The state sync instance. + + Returns: + A flat list of the warning messages returned by the executed pre-checks, + in migration order. + """ + versions = self.version_state.get_versions() + migrations = MIGRATIONS[versions.schema_version :] + + pre_check_warnings = [] + for migration in migrations: + if callable(pre_check := getattr(migration, "pre_check", None)): + migration_name = migration.__name__.split(".")[-1] + logger.info(f"Running pre-check for {migration_name}") + warnings = pre_check(state_sync) + if warnings: + pre_check_warnings.extend(warnings) + + return pre_check_warnings + def _apply_migrations( self, state_sync: StateSync, diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py index 873e1633df..9374fd90a5 100644 --- a/sqlmesh/core/state_sync/db/version.py +++ b/sqlmesh/core/state_sync/db/version.py @@ -13,6 +13,7 @@ SQLMESH_VERSION, ) from sqlmesh.core.state_sync.base import ( + PRE_CHECK_VERSION, SCHEMA_VERSION, Versions, ) @@ -31,6 +32,7 @@ def __init__(self, engine_adapter: EngineAdapter, schema: t.Optional[str] = None "schema_version": exp.DataType.build("int"), "sqlglot_version": 
exp.DataType.build(index_type), "sqlmesh_version": exp.DataType.build(index_type), + "pre_check_version": exp.DataType.build("int"), } def update_versions( @@ -38,6 +40,7 @@ def update_versions( schema_version: int = SCHEMA_VERSION, sqlglot_version: str = SQLGLOT_VERSION, sqlmesh_version: str = SQLMESH_VERSION, + pre_check_version: int = PRE_CHECK_VERSION, ) -> None: import pandas as pd @@ -51,6 +54,7 @@ def update_versions( "schema_version": schema_version, "sqlglot_version": sqlglot_version, "sqlmesh_version": sqlmesh_version, + "pre_check_version": pre_check_version, } ] ), @@ -69,5 +73,8 @@ def get_versions(self) -> Versions: return no_version return Versions( - schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) + schema_version=row[0], + sqlglot_version=row[1], + sqlmesh_version=seq_get(row, 2), + pre_check_version=seq_get(row, 3), ) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 454b6cd4ce..328ece80bf 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -709,11 +709,17 @@ def dag(self, context: Context, line: str) -> None: self.display(dag) @magic_arguments() + @argument( + "--pre-check", + action="store_true", + help="Run pre-checks and display warnings without performing migration", + ) @line_magic @pass_sqlmesh_context def migrate(self, context: Context, line: str) -> None: """Migrate SQLMesh to the current running version.""" - context.migrate() + args = parse_argstring(self.migrate, line) + context.migrate(pre_check_only=args.pre_check) context.console.log_success("Migration complete") @magic_arguments() diff --git a/sqlmesh/migrations/v0089_add_pre_check_version.py b/sqlmesh/migrations/v0089_add_pre_check_version.py new file mode 100644 index 0000000000..d470b66f02 --- /dev/null +++ b/sqlmesh/migrations/v0089_add_pre_check_version.py @@ -0,0 +1,24 @@ +"""Add new 'pre_check_version' column to the version state table.""" + +from sqlglot import exp + + +def migrate(state_sync, **kwargs): # type: ignore + 
engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + versions_table = "_versions" + if schema: + versions_table = f"{schema}.{versions_table}" + + alter_table_exp = exp.Alter( + this=exp.to_table(versions_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("pre_check_version"), + kind=exp.DataType.build("int"), + ) + ], + ) + + engine_adapter.execute(alter_table_exp) diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py index a5a6969e38..7e27634c2b 100644 --- a/tests/core/state_sync/test_state_sync.py +++ b/tests/core/state_sync/test_state_sync.py @@ -2,6 +2,7 @@ import logging import re import typing as t +from types import ModuleType from unittest.mock import call, patch import duckdb # noqa: TID253 @@ -11,8 +12,16 @@ from pytest_mock.plugin import MockerFixture from sqlglot import exp +from sqlmesh.cli.project_init import init_example_project from sqlmesh.core import constants as c -from sqlmesh.core.config import EnvironmentSuffixTarget +from sqlmesh.core.config import ( + Config, + DuckDBConnectionConfig, + EnvironmentSuffixTarget, + GatewayConfig, + ModelDefaultsConfig, +) +from sqlmesh.core.context import Context from sqlmesh.core.dialect import parse_one, schema_ from sqlmesh.core.engine_adapter import create_engine_adapter from sqlmesh.core.environment import Environment, EnvironmentStatements @@ -48,6 +57,7 @@ ) from sqlmesh.utils.date import now_timestamp, to_datetime, to_timestamp from sqlmesh.utils.errors import SQLMeshError +from tests.utils.test_helpers import use_terminal_console pytestmark = pytest.mark.slow @@ -3629,3 +3639,78 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync): "@grant_schema_usage()", "@grant_select_privileges()", ] + + +@use_terminal_console +def test_pre_checks(tmp_path, mocker): + init_example_project(tmp_path, engine_type="duckdb") + + db_path = str(tmp_path / "db.db") + config = Config( + gateways={"main": 
GatewayConfig(connection=DuckDBConnectionConfig(database=db_path))}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + context.plan(auto_apply=True, no_prompts=True) + + def mock_migrate(state_sync, **kwargs): + pass + + def mock_pre_check_with_warnings(state_sync): + return [ + "Warning: This migration will break compatibility with older versions", + "Warning: You must update all model configurations before applying this migration", + "Warning: Existing snapshots will need to be rebuilt", + ] + + def mock_pre_check_without_warnings(state_sync): + return [] + + # Create a mock migration module with a pre_check function + mock_migration = ModuleType("v9999_test_pre_check") + + setattr(mock_migration, "migrate", mock_migrate) + setattr(mock_migration, "pre_check", mock_pre_check_with_warnings) + + versions_before_migrate = context.state_sync.get_versions() + + import sqlmesh.core.state_sync as state_sync + + test_migrations = state_sync.db.migrator.MIGRATIONS + [mock_migration] + + # Test 1: Pre-check warnings are properly collected and displayed, user rejects migration + with ( + patch.object(state_sync.db.migrator, "MIGRATIONS", test_migrations), + patch.object(context.console, "_confirm", return_value=False), + ): + console = context.console + log_pre_check_warnings_spy = mocker.spy(console, "log_pre_check_warnings") + + context.migrate(pre_check_only=False) + + calls = log_pre_check_warnings_spy.mock_calls + assert len(calls) == 1 + + pre_check_warnings = calls[0].args[0] + assert len(pre_check_warnings) == 3 + assert all(warning.startswith("Warning:") for warning in pre_check_warnings) + + assert context.state_sync.get_versions() == versions_before_migrate + + update_versions_spy = mocker.spy(state_sync.db.version.VersionState, "update_versions") + + # Test 2: User accepts migration after being notified about pre-check warnings + with ( + patch.object(state_sync.db.migrator, "MIGRATIONS", 
test_migrations), + patch.object(context.console, "_confirm", return_value=True), + ): + context.migrate(pre_check_only=False) + assert len(update_versions_spy.mock_calls) == 1 + + # Test 3: Pre-check without warning should automatically result in a migration + setattr(mock_migration, "pre_check", mock_pre_check_without_warnings) + with patch.object(state_sync.db.migrator, "MIGRATIONS", test_migrations): + # Since the version module's SCHEMA_VERSION, etc, weren't patched, the old versions + # are still used, so the following should result in hitting the update_versions path + context.migrate(pre_check_only=False) + assert len(update_versions_spy.mock_calls) == 2