diff --git a/docs/cloud/tcloud_getting_started.md b/docs/cloud/tcloud_getting_started.md index dc0491814d..00ad8a3c25 100644 --- a/docs/cloud/tcloud_getting_started.md +++ b/docs/cloud/tcloud_getting_started.md @@ -268,9 +268,6 @@ Models needing backfill (missing dates): ├── sqlmesh_example.incremental_model: 2020-01-01 - 2024-11-24 └── sqlmesh_example.seed_model: 2024-11-24 - 2024-11-24 Apply - Backfill Tables [y/n]: y -Creating physical tables ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 - -All model versions have been created successfully [1/1] sqlmesh_example.seed_model evaluated in 0.00s [1/1] sqlmesh_example.incremental_model evaluated in 0.01s diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md index a184f7d99f..398117b3a9 100644 --- a/docs/concepts/macros/macro_variables.md +++ b/docs/concepts/macros/macro_variables.md @@ -130,8 +130,8 @@ SQLMesh provides additional predefined variables used to modify model behavior b * @runtime_stage - A string value denoting the current stage of the SQLMesh runtime. Typically used in models to conditionally execute pre/post-statements (learn more [here](../models/sql_models.md#optional-prepost-statements)). It returns one of these values: * 'loading' - The project is being loaded into SQLMesh's runtime context. - * 'creating' - The model tables are being created. - * 'evaluating' - The model query logic is being evaluated. + * 'creating' - The model tables are being created for the first time. The data may be inserted during table creation. + * 'evaluating' - The model query logic is evaluated, and the data is inserted into the existing model table. * 'promoting' - The model is being promoted in the target environment (view created during virtual layer update). * 'demoting' - The model is being demoted in the target environment (view dropped during virtual layer update). * 'auditing' - The audit is being run. diff --git a/docs/examples/incremental_time_full_walkthrough.md b/docs/examples/incremental_time_full_walkthrough.md index 6907836b0b..4e1d577d2c 100644 --- a/docs/examples/incremental_time_full_walkthrough.md +++ b/docs/examples/incremental_time_full_walkthrough.md @@ -304,10 +304,6 @@ Models needing backfill (missing dates): Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: Enter the backfill end date (eg. '1 month ago', '2020-01-01') or blank to backfill up until now: Apply - Backfill Tables [y/n]: y -Creating physical table ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:07 - -All model versions have been created successfully - [1/1] demo__dev.incrementals_demo evaluated in 6.97s Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:06 @@ -640,9 +636,10 @@ Models: ├── tcloud_raw_data.product_usage └── tcloud_raw_data.sales Apply - Virtual Update [y/n]: y -Creating physical tables ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 15/15 • 0:00:03 -All model versions have been created successfully +SKIP: No physical layer updates to perform + +SKIP: No model batches to execute Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:02 @@ -732,10 +729,6 @@ Models needing backfill (missing dates): Enter the preview start date (eg. '1 year', '2020-01-01') or blank to backfill to preview starting from yesterday: 2024-10-27 Enter the preview end date (eg. '1 month ago', '2020-01-01') or blank to preview up until '2024-11-08 00:00:00': Apply - Preview Tables [y/n]: y -Creating physical table ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:16 - -All model versions have been created successfully - [1/1] demo__dev.incrementals_demo evaluated in 6.18s Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:06 @@ -1249,9 +1242,10 @@ Models: THEN 'Regular User' Directly Modified: demo.incrementals_demo (Forward-only) Apply - Virtual Update [y/n]: y -Creating physical tables ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 15/15 • 0:00:02 -All model versions have been created successfully +SKIP: No physical layer updates to perform + +SKIP: No model batches to execute Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:02 diff --git a/docs/integrations/dlt.md b/docs/integrations/dlt.md index d772a7bd2c..a53dc184ea 100644 --- a/docs/integrations/dlt.md +++ b/docs/integrations/dlt.md @@ -102,10 +102,6 @@ Models needing backfill (missing dates): ├── sushi_dataset_sqlmesh.incremental_sushi_types: 2024-10-03 - 2024-10-03 └── sushi_dataset_sqlmesh.incremental_waiters: 2024-10-03 - 2024-10-03 Apply - Backfill Tables [y/n]: y -Creating physical table ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 - -All model versions have been created successfully - [1/1] sushi_dataset_sqlmesh.incremental__dlt_loads evaluated in 0.01s [1/1] sushi_dataset_sqlmesh.incremental_sushi_types evaluated in 0.00s [1/1] sushi_dataset_sqlmesh.incremental_waiters evaluated in 0.01s diff --git a/examples/multi/repo_1/models/a.sql b/examples/multi/repo_1/models/a.sql index 838a676ea1..31ef81b2d7 100644 --- a/examples/multi/repo_1/models/a.sql +++ b/examples/multi/repo_1/models/a.sql @@ -1,5 +1,6 @@ MODEL ( - name bronze.a + name bronze.a, + kind FULL ); SELECT diff --git a/examples/multi/repo_1/models/b.sql b/examples/multi/repo_1/models/b.sql index b32897705e..b80918d6d5 100644 --- a/examples/multi/repo_1/models/b.sql +++ b/examples/multi/repo_1/models/b.sql @@ -1,5 +1,6 @@ MODEL ( - name bronze.b + name bronze.b, + kind FULL ); SELECT diff --git a/examples/multi/repo_2/models/c.sql b/examples/multi/repo_2/models/c.sql index 6a5c42619c..08551704f4 100644 --- a/examples/multi/repo_2/models/c.sql +++ b/examples/multi/repo_2/models/c.sql @@ -1,5 +1,6 @@ MODEL ( - name silver.c + name silver.c, + kind FULL ); SELECT DISTINCT col_a diff --git a/examples/multi/repo_2/models/d.sql b/examples/multi/repo_2/models/d.sql index 3647ab6965..6935763f59 100644 --- a/examples/multi/repo_2/models/d.sql +++ b/examples/multi/repo_2/models/d.sql @@ -1,5 +1,6 @@ MODEL ( - name silver.d + name silver.d, + kind FULL ); SELECT diff --git a/examples/multi/repo_2/models/e.sql b/examples/multi/repo_2/models/e.sql index 34d0793328..168dbc143d 100644 --- a/examples/multi/repo_2/models/e.sql +++ b/examples/multi/repo_2/models/e.sql @@ -1,5 +1,6 @@ MODEL ( - name silver.e + name silver.e, + kind FULL ); SELECT diff --git a/examples/sushi/config.py b/examples/sushi/config.py index 0bf15d2767..b985e24ec5 100644 --- a/examples/sushi/config.py +++ b/examples/sushi/config.py @@ -1,6 +1,6 @@ import os -from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.config.common import VirtualEnvironmentMode, TableNamingConvention from sqlmesh.core.config import ( AutoCategorizationMode, BigQueryConnectionConfig, @@ -27,6 +27,11 @@ defaults = {"dialect": "duckdb"} model_defaults = ModelDefaultsConfig(**defaults) model_defaults_iceberg = ModelDefaultsConfig(**defaults, storage_format="iceberg") +before_all = [ + "CREATE SCHEMA IF NOT EXISTS raw", + "DROP VIEW IF EXISTS raw.demographics", + "CREATE VIEW raw.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)", +] # A DuckDB config, in-memory by default. @@ -52,6 +57,7 @@ "nomissingexternalmodels", ], ), + before_all=before_all, ) bigquery_config = Config( @@ -63,6 +69,7 @@ }, default_gateway="bq", model_defaults=model_defaults, + before_all=before_all, ) # A configuration used for SQLMesh tests. @@ -75,6 +82,7 @@ ) ), model_defaults=model_defaults, + before_all=before_all, ) # A configuration used for SQLMesh tests with virtual environment mode set to DEV_ONLY. @@ -84,7 +92,7 @@ "plan": PlanConfig( auto_categorize_changes=CategorizerConfig.all_full(), ), - } + }, ) # A DuckDB config with a physical schema map. @@ -92,6 +100,7 @@ default_connection=DuckDBConnectionConfig(), physical_schema_mapping={"^sushi$": "company_internal"}, model_defaults=model_defaults, + before_all=before_all, ) # A config representing isolated systems with a gateway per system @@ -103,6 +112,7 @@ }, default_gateway="dev", model_defaults=model_defaults, + before_all=before_all, ) required_approvers_config = Config( @@ -137,6 +147,7 @@ ), ], model_defaults=model_defaults, + before_all=before_all, ) @@ -144,12 +155,13 @@ default_connection=DuckDBConnectionConfig(), model_defaults=model_defaults, environment_suffix_target=EnvironmentSuffixTarget.TABLE, + before_all=before_all, ) environment_suffix_catalog_config = environment_suffix_table_config.model_copy( update={ "environment_suffix_target": EnvironmentSuffixTarget.CATALOG, - } + }, ) CATALOGS = { @@ -161,6 +173,7 @@ default_connection=DuckDBConnectionConfig(catalogs=CATALOGS), default_test_connection=DuckDBConnectionConfig(catalogs=CATALOGS), model_defaults=model_defaults, + before_all=before_all, ) environment_catalog_mapping_config = Config( @@ -177,4 +190,13 @@ "^prod$": "prod_catalog", ".*": "dev_catalog", }, + before_all=before_all, +) + +hash_md5_naming_config = config.copy( + update={"physical_table_naming_convention": TableNamingConvention.HASH_MD5} +) + +table_only_naming_config = config.copy( + update={"physical_table_naming_convention": TableNamingConvention.TABLE_ONLY} ) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index cf87fd7443..43283ead90 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -4074,8 +4074,8 @@ def _format_node_error(ex: NodeExecutionFailedError) -> str: node_name = "" if isinstance(error.node, SnapshotId): node_name = error.node.name - elif isinstance(error.node, tuple): - node_name = error.node[0] + elif hasattr(error.node, "snapshot_name"): + node_name = error.node.snapshot_name msg = _format_node_error(error) msg = " " + msg.replace("\n", "\n ") diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index ff19a3c7c6..12da39f50f 100644 --- a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -222,7 +222,9 @@ def create( infer_python_dependencies=infer_python_dependencies, ) - previous_environment_statements = state_reader.get_environment_statements(environment) + previous_environment_statements = ( + state_reader.get_environment_statements(env.name) if env else [] + ) if existing_env and always_recreate_environment: previous_plan_id: t.Optional[str] = existing_env.plan_id @@ -288,7 +290,7 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD previous_finalized_snapshots=env.previous_finalized_snapshots, previous_requirements=env.requirements, requirements=env.requirements, - previous_environment_statements=[], + previous_environment_statements=environment_statements, environment_statements=environment_statements, previous_gateway_managed_virtual_layer=env.gateway_managed, gateway_managed_virtual_layer=env.gateway_managed, diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 94ffbe81d2..24ee99bba5 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -443,14 +443,20 @@ def replace_query( target_table=target_table, source_columns=source_columns, ) + if not target_columns_to_types and table_exists: + target_columns_to_types = self.columns(target_table) query = source_queries[0].query_factory() - target_columns_to_types = target_columns_to_types or self.columns(target_table) self_referencing = any( quote_identifiers(table) == quote_identifiers(target_table) for table in query.find_all(exp.Table) ) # If a query references itself then it must have a table created regardless of approach used. if self_referencing: + if not target_columns_to_types: + raise SQLMeshError( + f"Cannot create a self-referencing table {target_table.sql(dialect=self.dialect)} without knowing the column types. " + "Try casting the columns to an expected type or defining the columns in the model metadata. " + ) self._create_table_from_columns( target_table, target_columns_to_types, @@ -472,6 +478,7 @@ def replace_query( **kwargs, ) if self_referencing: + assert target_columns_to_types is not None with self.temp_table( self._select_columns(target_columns_to_types).from_(target_table), name=target_table, diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index cc394efd9e..aa46fba95a 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -53,7 +53,9 @@ def columns( self.execute(sql) resp = self.cursor.fetchall() if not resp: - raise SQLMeshError("Could not get columns for table '%s'. Table not found.", table_name) + raise SQLMeshError( + f"Could not get columns for table '{table.sql(dialect=self.dialect)}'. Table not found." + ) return { column_name: exp.DataType.build(data_type, dialect=self.dialect, udt=True) for column_name, data_type in resp diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index a1adca56fb..fbe0b7bca9 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -16,7 +16,7 @@ ) from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.plan.common import should_force_rebuild +from sqlmesh.core.plan.common import should_force_rebuild, is_breaking_kind_change from sqlmesh.core.plan.definition import ( Plan, SnapshotMapping, @@ -597,7 +597,7 @@ def _categorize_snapshots( forward_only = self._forward_only or self._is_forward_only_change(s_id) if forward_only and s_id.name in self._context_diff.modified_snapshots: new, old = self._context_diff.modified_snapshots[s_id.name] - if should_force_rebuild(old, new) or snapshot.is_seed: + if is_breaking_kind_change(old, new) or snapshot.is_seed: # Breaking kind changes and seed changes can't be forward-only. forward_only = False @@ -622,7 +622,7 @@ def _categorize_snapshot( if self._context_diff.directly_modified(s_id.name): if self._auto_categorization_enabled: new, old = self._context_diff.modified_snapshots[s_id.name] - if should_force_rebuild(old, new): + if is_breaking_kind_change(old, new): snapshot.categorize_as(SnapshotChangeCategory.BREAKING, False) return @@ -780,7 +780,7 @@ def _is_forward_only_change(self, s_id: SnapshotId) -> bool: if snapshot.name in self._context_diff.modified_snapshots: _, old = self._context_diff.modified_snapshots[snapshot.name] # If the model kind has changed in a breaking way, then we can't consider this to be a forward-only change. - if snapshot.is_model and should_force_rebuild(old, snapshot): + if snapshot.is_model and is_breaking_kind_change(old, snapshot): return False return ( snapshot.is_model and snapshot.model.forward_only and bool(snapshot.previous_versions) diff --git a/sqlmesh/core/plan/common.py b/sqlmesh/core/plan/common.py index e6b7a4d10c..8d31b0ead3 100644 --- a/sqlmesh/core/plan/common.py +++ b/sqlmesh/core/plan/common.py @@ -4,6 +4,13 @@ def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool: + if new.is_view and new.is_indirect_non_breaking and not new.is_forward_only: + # View models always need to be rebuilt to reflect updated upstream dependencies. + return True + return is_breaking_kind_change(old, new) + + +def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool: if old.virtual_environment_mode != new.virtual_environment_mode: # If the virtual environment mode has changed, then we need to rebuild return True diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index ced4631b99..46142b7eeb 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -205,6 +205,16 @@ def visit_physical_layer_update_stage( success=completion_status is not None and completion_status.is_success ) + def visit_physical_layer_schema_creation_stage( + self, stage: stages.PhysicalLayerSchemaCreationStage, plan: EvaluatablePlan + ) -> None: + try: + self.snapshot_evaluator.create_physical_schemas( + stage.snapshots, stage.deployability_index + ) + except Exception as ex: + raise PlanError("Plan application failed.") from ex + def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePlan) -> None: if plan.empty_backfill: intervals_to_add = [] @@ -243,6 +253,8 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla circuit_breaker=self._circuit_breaker, start=plan.start, end=plan.end, + allow_destructive_snapshots=plan.allow_destructive_models, + selected_snapshot_ids=stage.selected_snapshot_ids, ) if errors: raise PlanError("Plan application failed.") diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 871b540203..82223dd807 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -1,6 +1,7 @@ import typing as t from dataclasses import dataclass +from sqlmesh.core import constants as c from sqlmesh.core.environment import EnvironmentStatements, EnvironmentNamingInfo, Environment from sqlmesh.core.plan.common import should_force_rebuild from sqlmesh.core.plan.definition import EvaluatablePlan @@ -71,6 +72,19 @@ class PhysicalLayerUpdateStage: deployability_index: DeployabilityIndex +@dataclass +class PhysicalLayerSchemaCreationStage: + """Create the physical schemas for the given snapshots. + + Args: + snapshots: Snapshots to create physical schemas for. + deployability_index: Deployability index for this stage. + """ + + snapshots: t.List[Snapshot] + deployability_index: DeployabilityIndex + + @dataclass class AuditOnlyRunStage: """Run audits only for given snapshots. @@ -102,12 +116,14 @@ class BackfillStage: Args: snapshot_to_intervals: Intervals to backfill. This collection can be empty in which case no backfill is needed. This can be useful to report the lack of backfills back to the user. + selected_snapshot_ids: The snapshots to include in the run DAG. all_snapshots: All snapshots in the plan by name. deployability_index: Deployability index for this stage. before_promote: Whether this stage is before the promotion stage. """ snapshot_to_intervals: SnapshotToIntervals + selected_snapshot_ids: t.Set[SnapshotId] all_snapshots: t.Dict[str, Snapshot] deployability_index: DeployabilityIndex before_promote: bool = True @@ -185,6 +201,7 @@ class FinalizeEnvironmentStage: AfterAllStage, CreateSnapshotRecordsStage, PhysicalLayerUpdateStage, + PhysicalLayerSchemaCreationStage, AuditOnlyRunStage, RestatementStage, BackfillStage, @@ -236,7 +253,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: self._adjust_intervals(snapshots_by_name, plan, existing_environment) deployability_index = DeployabilityIndex.create(snapshots, start=plan.start) - deployability_index_for_creation = deployability_index if plan.is_dev: before_promote_snapshots = all_selected_for_backfill_snapshots after_promote_snapshots = set() @@ -283,11 +299,23 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: if plan.new_snapshots: stages.append(CreateSnapshotRecordsStage(snapshots=plan.new_snapshots)) - stages.append( - self._get_physical_layer_update_stage( - plan, snapshots, snapshots_to_intervals, deployability_index_for_creation + snapshots_to_create = self._get_snapshots_to_create(plan, snapshots) + if snapshots_to_create: + stages.append( + PhysicalLayerSchemaCreationStage( + snapshots=snapshots_to_create, deployability_index=deployability_index + ) + ) + if not needs_backfill: + stages.append( + self._get_physical_layer_update_stage( + plan, + snapshots_to_create, + snapshots, + snapshots_to_intervals, + deployability_index, + ) ) - ) audit_only_snapshots = self._get_audit_only_snapshots(new_snapshots) if audit_only_snapshots: @@ -301,6 +329,11 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: stages.append( BackfillStage( snapshot_to_intervals=missing_intervals_before_promote, + selected_snapshot_ids={ + s_id + for s_id in before_promote_snapshots + if plan.is_selected_for_backfill(s_id.name) + }, all_snapshots=snapshots_by_name, deployability_index=deployability_index, ) @@ -310,6 +343,7 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: stages.append( BackfillStage( snapshot_to_intervals={}, + selected_snapshot_ids=set(), all_snapshots=snapshots_by_name, deployability_index=deployability_index, ) @@ -326,7 +360,7 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: MigrateSchemasStage( snapshots=snapshots_with_schema_migration, all_snapshots=snapshots, - deployability_index=deployability_index_for_creation, + deployability_index=deployability_index, ) ) @@ -340,6 +374,11 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: stages.append( BackfillStage( snapshot_to_intervals=missing_intervals_after_promote, + selected_snapshot_ids={ + s_id + for s_id in after_promote_snapshots + if plan.is_selected_for_backfill(s_id.name) + }, all_snapshots=snapshots_by_name, deployability_index=deployability_index, ) @@ -418,13 +457,14 @@ def _get_restatement_stage( def _get_physical_layer_update_stage( self, plan: EvaluatablePlan, - snapshots: t.Dict[SnapshotId, Snapshot], + snapshots_to_create: t.List[Snapshot], + all_snapshots: t.Dict[SnapshotId, Snapshot], snapshots_to_intervals: SnapshotToIntervals, deployability_index: DeployabilityIndex, ) -> PhysicalLayerUpdateStage: return PhysicalLayerUpdateStage( - snapshots=self._get_snapshots_to_create(plan, snapshots), - all_snapshots=snapshots, + snapshots=snapshots_to_create, + all_snapshots=all_snapshots, snapshots_with_missing_intervals={ s.snapshot_id for s in snapshots_to_intervals @@ -589,6 +629,9 @@ def _adjust_intervals( # Make sure the intervals are up to date and restatements are reflected self.state_reader.refresh_snapshot_intervals(snapshots_by_name.values()) + if not existing_environment: + existing_environment = self.state_reader.get_environment(c.PROD) + if existing_environment: new_snapshot_ids = set() new_snapshot_versions = set() diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 4582b24485..e787e57a23 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -1,4 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass +import abc import logging import typing as t import time @@ -39,7 +41,6 @@ from sqlmesh.utils.date import ( TimeLike, now_timestamp, - to_timestamp, validate_date_range, ) from sqlmesh.utils.errors import ( @@ -55,9 +56,43 @@ logger = logging.getLogger(__name__) SnapshotToIntervals = t.Dict[Snapshot, Intervals] -# we store snapshot name instead of snapshots/snapshotids because pydantic -# is extremely slow to hash. snapshot names should be unique within a dag run -SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]] + + +class SchedulingUnit(abc.ABC): + snapshot_name: str + + def __lt__(self, other: SchedulingUnit) -> bool: + return (self.__class__.__name__, self.snapshot_name) < ( + other.__class__.__name__, + other.snapshot_name, + ) + + +@dataclass(frozen=True) +class EvaluateNode(SchedulingUnit): + snapshot_name: str + interval: Interval + batch_index: int + + def __lt__(self, other: SchedulingUnit) -> bool: + if not isinstance(other, EvaluateNode): + return super().__lt__(other) + return (self.__class__.__name__, self.snapshot_name, self.interval, self.batch_index) < ( + other.__class__.__name__, + other.snapshot_name, + other.interval, + other.batch_index, + ) + + +@dataclass(frozen=True) +class CreateNode(SchedulingUnit): + snapshot_name: str + + +@dataclass(frozen=True) +class DummyNode(SchedulingUnit): + snapshot_name: str class Scheduler: @@ -161,6 +196,8 @@ def evaluate( deployability_index: DeployabilityIndex, batch_index: int, environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + target_table_exists: t.Optional[bool] = None, **kwargs: t.Any, ) -> t.List[AuditResult]: """Evaluate a snapshot and add the processed interval to the state sync. @@ -170,9 +207,11 @@ def evaluate( start: The start datetime to render. end: The end datetime to render. execution_time: The date/time time reference to use for execution time. Defaults to now. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. deployability_index: Determines snapshots that are deployable in the context of this evaluation. batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it auto_restatement_enabled: Whether to enable auto restatements. + target_table_exists: Whether the target table exists. If None, the table will be checked for existence. kwargs: Additional kwargs to pass to the renderer. Returns: @@ -190,8 +229,10 @@ def evaluate( end=end, execution_time=execution_time, snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots, deployability_index=deployability_index, batch_index=batch_index, + target_table_exists=target_table_exists, **kwargs, ) audit_results = self._audit_snapshot( @@ -289,8 +330,9 @@ def batch_intervals( merged_intervals: SnapshotToIntervals, deployability_index: t.Optional[DeployabilityIndex], environment_naming_info: EnvironmentNamingInfo, + dag: t.Optional[DAG[SnapshotId]] = None, ) -> t.Dict[Snapshot, Intervals]: - dag = snapshots_to_dag(merged_intervals) + dag = dag or snapshots_to_dag(merged_intervals) snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = { snapshot.snapshot_id: ( @@ -369,6 +411,8 @@ def run_merged_intervals( circuit_breaker: t.Optional[t.Callable[[], bool]] = None, start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, audit_only: bool = False, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: @@ -382,14 +426,22 @@ def run_merged_intervals( circuit_breaker: An optional handler which checks if the run should be aborted. start: The start of the run. end: The end of the run. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + selected_snapshot_ids: The snapshots to include in the run DAG. If None, all snapshots with missing intervals will be included. Returns: A tuple of errors and skipped intervals. """ execution_time = execution_time or now_timestamp() + selected_snapshots = [self.snapshots[sid] for sid in (selected_snapshot_ids or set())] + if not selected_snapshots: + selected_snapshots = list(merged_intervals) + + snapshot_dag = snapshots_to_dag(selected_snapshots) + batched_intervals = self.batch_intervals( - merged_intervals, deployability_index, environment_naming_info + merged_intervals, deployability_index, environment_naming_info, dag=snapshot_dag ) self.console.start_evaluation_progress( @@ -399,7 +451,16 @@ def run_merged_intervals( audit_only=audit_only, ) - dag = self._dag(batched_intervals) + snapshots_to_create = { + s.snapshot_id + for s in self.snapshot_evaluator.get_snapshots_to_create( + selected_snapshots, deployability_index + ) + } + + dag = self._dag( + batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create + ) if run_environment_statements: environment_statements = self.state_sync.get_environment_statements( @@ -417,70 +478,81 @@ def run_merged_intervals( execution_time=execution_time, ) - def evaluate_node(node: SchedulingUnit) -> None: + def run_node(node: SchedulingUnit) -> None: if circuit_breaker and circuit_breaker(): raise CircuitBreakerError() - - snapshot_name, ((start, end), batch_idx) = node - if batch_idx == -1: + if isinstance(node, DummyNode): return - snapshot = self.snapshots_by_name[snapshot_name] - self.console.start_snapshot_evaluation_progress(snapshot) - - execution_start_ts = now_timestamp() - evaluation_duration_ms: t.Optional[int] = None - - audit_results: t.List[AuditResult] = [] - try: - assert execution_time # mypy - assert deployability_index # mypy - - if audit_only: - audit_results = self._audit_snapshot( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - deployability_index=deployability_index, - snapshots=self.snapshots_by_name, - start=start, - end=end, - execution_time=execution_time, - ) - else: - audit_results = self.evaluate( - snapshot=snapshot, - environment_naming_info=environment_naming_info, - start=start, - end=end, - execution_time=execution_time, - deployability_index=deployability_index, - batch_index=batch_idx, + snapshot = self.snapshots_by_name[node.snapshot_name] + + if isinstance(node, EvaluateNode): + self.console.start_snapshot_evaluation_progress(snapshot) + execution_start_ts = now_timestamp() + evaluation_duration_ms: t.Optional[int] = None + start, end = node.interval + + audit_results: t.List[AuditResult] = [] + try: + assert execution_time # mypy + assert deployability_index # mypy + + if audit_only: + audit_results = self._audit_snapshot( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, + snapshots=self.snapshots_by_name, + start=start, + end=end, + execution_time=execution_time, + ) + else: + audit_results = self.evaluate( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + start=start, + end=end, + execution_time=execution_time, + deployability_index=deployability_index, + batch_index=node.batch_index, + allow_destructive_snapshots=allow_destructive_snapshots, + target_table_exists=snapshot.snapshot_id not in snapshots_to_create, + ) + + evaluation_duration_ms = now_timestamp() - execution_start_ts + finally: + num_audits = len(audit_results) + num_audits_failed = sum(1 for result in audit_results if result.count) + self.console.update_snapshot_evaluation_progress( + snapshot, + batched_intervals[snapshot][node.batch_index], + node.batch_index, + evaluation_duration_ms, + num_audits - num_audits_failed, + num_audits_failed, ) - - evaluation_duration_ms = now_timestamp() - execution_start_ts - finally: - num_audits = len(audit_results) - num_audits_failed = sum(1 for result in audit_results if result.count) - self.console.update_snapshot_evaluation_progress( - snapshot, - batched_intervals[snapshot][batch_idx], - batch_idx, - evaluation_duration_ms, - num_audits - num_audits_failed, - num_audits_failed, + elif isinstance(node, CreateNode): + self.snapshot_evaluator.create_snapshot( + snapshot=snapshot, + snapshots=self.snapshots_by_name, + deployability_index=deployability_index, + allow_destructive_snapshots=allow_destructive_snapshots or set(), ) try: with self.snapshot_evaluator.concurrent_context(): errors, skipped_intervals = concurrent_apply_to_dag( dag, - evaluate_node, + run_node, self.max_workers, raise_on_error=False, ) self.console.stop_evaluation_progress(success=not errors) - skipped_snapshots = {i[0] for i in skipped_intervals} + skipped_snapshots = { + i.snapshot_name for i in skipped_intervals if isinstance(i, EvaluateNode) + } self.console.log_skipped_models(skipped_snapshots) for skipped in skipped_snapshots: logger.info(f"SKIPPED snapshot {skipped}\n") @@ -509,11 +581,18 @@ def evaluate_node(node: SchedulingUnit) -> None: self.state_sync.recycle() - def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]: + def _dag( + self, + batches: SnapshotToIntervals, + snapshot_dag: t.Optional[DAG[SnapshotId]] = None, + snapshots_to_create: t.Optional[t.Set[SnapshotId]] = None, + ) -> DAG[SchedulingUnit]: """Builds a DAG of snapshot intervals to be evaluated. Args: batches: The batches of snapshots and intervals to evaluate. + snapshot_dag: The DAG of all snapshots. + snapshots_to_create: The snapshots with missing physical tables. Returns: A DAG of snapshot intervals to be evaluated. @@ -522,46 +601,72 @@ def _dag(self, batches: SnapshotToIntervals) -> DAG[SchedulingUnit]: intervals_per_snapshot = { snapshot.name: intervals for snapshot, intervals in batches.items() } + snapshots_to_create = snapshots_to_create or set() + original_snapshots_to_create = snapshots_to_create.copy() + snapshot_dag = snapshot_dag or snapshots_to_dag(batches) dag = DAG[SchedulingUnit]() - terminal_node = ((to_timestamp(0), to_timestamp(0)), -1) - for snapshot, intervals in batches.items(): - if not intervals: - continue + for snapshot_id in snapshot_dag: + snapshot = self.snapshots_by_name[snapshot_id.name] + intervals = intervals_per_snapshot.get(snapshot.name, []) - upstream_dependencies = [] + upstream_dependencies: t.List[SchedulingUnit] = [] for p_sid in snapshot.parents: if p_sid in self.snapshots: p_intervals = intervals_per_snapshot.get(p_sid.name, []) - if len(p_intervals) > 1: - upstream_dependencies.append((p_sid.name, terminal_node)) + if not p_intervals and p_sid in original_snapshots_to_create: + upstream_dependencies.append(CreateNode(snapshot_name=p_sid.name)) + elif len(p_intervals) > 1: + upstream_dependencies.append(DummyNode(snapshot_name=p_sid.name)) else: for i, interval in enumerate(p_intervals): - upstream_dependencies.append((p_sid.name, (interval, i))) + upstream_dependencies.append( + EvaluateNode( + snapshot_name=p_sid.name, interval=interval, batch_index=i + ) + ) batch_concurrency = snapshot.node.batch_concurrency + batch_size = snapshot.node.batch_size if snapshot.depends_on_past: batch_concurrency = 1 + create_node: t.Optional[CreateNode] = None + if snapshot.snapshot_id in original_snapshots_to_create and ( + snapshot.is_incremental_by_time_range + or ((not batch_concurrency or batch_concurrency > 1) and batch_size) + or not intervals + ): + # Add a separate node for table creation in case when there multiple concurrent + # evaluation nodes or when there are no intervals to evaluate. + create_node = CreateNode(snapshot_name=snapshot.name) + dag.add(create_node, upstream_dependencies) + snapshots_to_create.remove(snapshot.snapshot_id) + for i, interval in enumerate(intervals): - node = (snapshot.name, (interval, i)) - dag.add(node, upstream_dependencies) + node = EvaluateNode(snapshot_name=snapshot.name, interval=interval, batch_index=i) + + if create_node: + dag.add(node, [create_node]) + else: + dag.add(node, upstream_dependencies) if len(intervals) > 1: - dag.add((snapshot.name, terminal_node), [node]) + dag.add(DummyNode(snapshot_name=snapshot.name), [node]) if batch_concurrency and i >= batch_concurrency: batch_idx_to_wait_for = i - batch_concurrency dag.add( node, [ - ( - snapshot.name, - (intervals[batch_idx_to_wait_for], batch_idx_to_wait_for), - ) + EvaluateNode( + snapshot_name=snapshot.name, + interval=intervals[batch_idx_to_wait_for], + batch_index=batch_idx_to_wait_for, + ), ], ) return dag diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 996d539e60..ec5a883f7f 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -1043,8 +1043,15 @@ def categorize_as(self, category: SnapshotChangeCategory, forward_only: bool = F # If the model has a pinned version then use that. self.version = self.model.physical_version elif is_no_rebuild and self.previous_version: + self.version = self.previous_version.data_version.version + elif self.is_model and self.model.forward_only and not self.previous_version: + # If this is a new model then use a deterministic version, independent of the fingerprint. + self.version = hash_data([self.name, *self.model.kind.data_hash_values]) + else: + self.version = self.fingerprint.to_version() + + if is_no_rebuild and self.previous_version: previous_version = self.previous_version - self.version = previous_version.data_version.version self.physical_schema_ = previous_version.physical_schema self.table_naming_convention = previous_version.table_naming_convention if self.is_materialized and (category.is_indirect_non_breaking or category.is_metadata): @@ -1054,11 +1061,6 @@ def categorize_as(self, category: SnapshotChangeCategory, forward_only: bool = F or previous_version.fingerprint.to_version() ) self.dev_table_suffix = previous_version.data_version.dev_table_suffix - elif self.is_model and self.model.forward_only and not self.previous_version: - # If this is a new model then use a deterministic version, independent of the fingerprint. - self.version = hash_data([self.name, *self.model.kind.data_hash_values]) - else: - self.version = self.fingerprint.to_version() self.change_category = category self.forward_only = forward_only @@ -1383,12 +1385,11 @@ def requires_schema_migration_in_prod(self) -> bool: return ( self.is_paused and self.is_model - and not self.is_symbolic + and self.is_materialized and ( (self.previous_version and self.previous_version.version == self.version) or self.model.forward_only or bool(self.model.physical_version) - or self.is_view or not self.virtual_environment_mode.is_full ) ) @@ -1588,7 +1589,9 @@ def create( # Similarly, if the model depends on past and the start date is not aligned with the # model's start, we should consider this snapshot non-deployable. this_deployable = False - if not snapshot.is_paused or snapshot.is_indirect_non_breaking: + if not snapshot.is_paused or ( + snapshot.is_indirect_non_breaking and snapshot.intervals + ): # This snapshot represents what's currently deployed in prod. representative_shared_version_ids.add(node) diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index f2f7044ba7..1531997c1b 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -60,7 +60,6 @@ SnapshotInfoLike, SnapshotTableCleanupTask, ) -from sqlmesh.core.snapshot.definition import parent_snapshots_by_name from sqlmesh.utils import random_id, CorrelationId from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, @@ -138,8 +137,10 @@ def evaluate( end: TimeLike, execution_time: TimeLike, snapshots: t.Dict[str, Snapshot], + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, batch_index: int = 0, + target_table_exists: t.Optional[bool] = None, **kwargs: t.Any, ) -> t.Optional[str]: """Renders the snapshot's model, executes it and stores the result in the snapshot's physical table. @@ -150,21 +151,25 @@ def evaluate( end: The end datetime to render. execution_time: The date/time time reference to use for execution time. snapshots: All upstream snapshots (by name) to use for expansion and mapping of physical locations. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. deployability_index: Determines snapshots that are deployable in the context of this evaluation. batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it + target_table_exists: Whether the target table exists. If None, the table will be checked for existence. kwargs: Additional kwargs to pass to the renderer. Returns: The WAP ID of this evaluation if supported, None otherwise. """ result = self._evaluate_snapshot( - snapshot, - start, - end, - execution_time, - snapshots, + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots or set(), deployability_index=deployability_index, batch_index=batch_index, + target_table_exists=target_table_exists, **kwargs, ) if result is None or isinstance(result, str): @@ -200,21 +205,40 @@ def evaluate_and_fetch( Returns: The result of the evaluation as a dataframe. """ - result = self._evaluate_snapshot( + import pandas as pd + + adapter = self.get_adapter(snapshot.model.gateway) + render_kwargs = dict( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + runtime_stage=RuntimeStage.EVALUATING, + **kwargs, + ) + queries_or_dfs = self._render_snapshot_for_evaluation( snapshot, - start, - end, - execution_time, snapshots, - limit=limit, - deployability_index=deployability_index, - **kwargs, + deployability_index or DeployabilityIndex.all_deployable(), + render_kwargs, ) - if result is None or isinstance(result, str): - raise SQLMeshError( - f"Unexpected result {result} when evaluating snapshot {snapshot.snapshot_id}." - ) - return result + query_or_df = next(queries_or_dfs) + if isinstance(query_or_df, pd.DataFrame): + return query_or_df.head(limit) + if not isinstance(query_or_df, exp.Expression): + # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark / bigframe dataframe, + # so we use `limit` instead of `head` to get back a dataframe instead of List[Row] + # https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head + return query_or_df.limit(limit) + + assert isinstance(query_or_df, exp.Query) + + existing_limit = query_or_df.args.get("limit") + if existing_limit: + limit = min(limit, execute(exp.select(existing_limit.expression)).rows[0][0]) + assert limit is not None + + return adapter._fetch_native_df(query_or_df.limit(limit)) def promote( self, @@ -254,7 +278,11 @@ def promote( for gateway, tables in tables_by_gateway.items(): if environment_naming_info.suffix_target.is_catalog: self._create_catalogs(tables=tables, gateway=gateway) - self._create_schemas(tables=tables, gateway=gateway) + + gateway_table_pairs = [ + (gateway, table) for gateway, tables in tables_by_gateway.items() for table in tables + ] + self._create_schemas(gateway_table_pairs=gateway_table_pairs) deployability_index = deployability_index or DeployabilityIndex.all_deployable() with self.concurrent_context(): @@ -324,32 +352,66 @@ def create( Returns: CompletionStatus: The status of the creation operation (success, failure, nothing to do). """ + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + + snapshots_to_create = self.get_snapshots_to_create(target_snapshots, deployability_index) + if not snapshots_to_create: + return CompletionStatus.NOTHING_TO_DO + if on_start: + on_start(snapshots_to_create) + + self._create_snapshots( + snapshots_to_create=snapshots_to_create, + snapshots={s.name: s for s in snapshots.values()}, + deployability_index=deployability_index, + on_complete=on_complete, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + ) + return CompletionStatus.SUCCESS + + def create_physical_schemas( + self, snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex + ) -> None: + """Creates the physical schemas for the given snapshots. + + Args: + snapshots: Snapshots to create physical schemas for. + deployability_index: Determines snapshots that are deployable in the context of this creation. + """ + tables_by_gateway: t.Dict[t.Optional[str], t.List[str]] = defaultdict(list) + for snapshot in snapshots: + if snapshot.is_model and not snapshot.is_symbolic: + tables_by_gateway[snapshot.model_gateway].append( + snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) + ) + + gateway_table_pairs = [ + (gateway, table) for gateway, tables in tables_by_gateway.items() for table in tables + ] + self._create_schemas(gateway_table_pairs=gateway_table_pairs) + + def get_snapshots_to_create( + self, target_snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex + ) -> t.List[Snapshot]: + """Returns a list of snapshots that need to have their physical tables created. + + Args: + target_snapshots: Target snapshots. + deployability_index: Determines snapshots that are deployable / representative in the context of this creation. + """ snapshots_with_table_names = defaultdict(set) tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = ( defaultdict(lambda: defaultdict(set)) ) - table_deployability: t.Dict[str, bool] = {} - allow_destructive_snapshots = allow_destructive_snapshots or set() for snapshot in target_snapshots: if not snapshot.is_model or snapshot.is_symbolic: continue - deployability_flags = [True] - if ( - snapshot.is_no_rebuild - or snapshot.is_managed - or (snapshot.is_model and snapshot.model.forward_only) - or (deployability_index and not deployability_index.is_deployable(snapshot)) - ): - deployability_flags.append(False) - for is_deployable in deployability_flags: - table = exp.to_table( - snapshot.table_name(is_deployable), dialect=snapshot.model.dialect - ) - snapshots_with_table_names[snapshot].add(table.name) - table_deployability[table.name] = is_deployable - table_schema = d.schema_(table.db, catalog=table.catalog) - tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name) + is_deployable = deployability_index.is_deployable(snapshot) + table = exp.to_table(snapshot.table_name(is_deployable), dialect=snapshot.model.dialect) + snapshots_with_table_names[snapshot].add(table.name) + table_schema = d.schema_(table.db, catalog=table.catalog) + tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name) def _get_data_objects( schema: exp.Table, @@ -378,41 +440,18 @@ def _get_data_objects( existing_objects.update(objs_for_gateway) snapshots_to_create = [] - target_deployability_flags: t.Dict[str, t.List[bool]] = defaultdict(list) for snapshot, table_names in snapshots_with_table_names.items(): missing_tables = table_names - existing_objects if missing_tables or (snapshot.is_seed and not snapshot.intervals): snapshots_to_create.append(snapshot) - for table_name in missing_tables or table_names: - target_deployability_flags[snapshot.name].append( - table_deployability[table_name] - ) - target_deployability_flags[snapshot.name].sort() - - if not snapshots_to_create: - return CompletionStatus.NOTHING_TO_DO - if on_start: - on_start(snapshots_to_create) - for gateway, tables_by_schema in tables_by_gateway_and_schema.items(): - self._create_schemas(tables=tables_by_schema, gateway=gateway) - - self._create_snapshots( - snapshots_to_create=snapshots_to_create, - snapshots=snapshots, - target_deployability_flags=target_deployability_flags, - deployability_index=deployability_index, - on_complete=on_complete, - allow_destructive_snapshots=allow_destructive_snapshots, - ) - return CompletionStatus.SUCCESS + return snapshots_to_create def _create_snapshots( self, snapshots_to_create: t.Iterable[Snapshot], - snapshots: t.Dict[SnapshotId, Snapshot], - target_deployability_flags: t.Dict[str, t.List[bool]], - deployability_index: t.Optional[DeployabilityIndex], + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], allow_destructive_snapshots: t.Set[str], ) -> None: @@ -420,13 +459,12 @@ def _create_snapshots( with self.concurrent_context(): errors, skipped = concurrent_apply_to_snapshots( snapshots_to_create, - lambda s: self._create_snapshot( + lambda s: self.create_snapshot( s, snapshots=snapshots, - deployability_flags=target_deployability_flags[s.name], deployability_index=deployability_index, - on_complete=on_complete, allow_destructive_snapshots=allow_destructive_snapshots, + on_complete=on_complete, ), self.ddl_concurrent_tasks, raise_on_error=False, @@ -451,12 +489,13 @@ def migrate( """ allow_destructive_snapshots = allow_destructive_snapshots or set() deployability_index = deployability_index or DeployabilityIndex.all_deployable() + snapshots_by_name = {s.name: s for s in snapshots.values()} with self.concurrent_context(): concurrent_apply_to_snapshots( target_snapshots, lambda s: self._migrate_snapshot( s, - snapshots, + snapshots_by_name, allow_destructive_snapshots, self.get_adapter(s.model_gateway), deployability_index, @@ -612,18 +651,29 @@ def close(self) -> None: except Exception: logger.exception("Failed to close Snapshot Evaluator") + def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator: + return SnapshotEvaluator( + { + gateway: adapter.with_settings(correlation_id=correlation_id) + for gateway, adapter in self.adapters.items() + }, + self.ddl_concurrent_tasks, + self.selected_gateway, + ) + def _evaluate_snapshot( self, - snapshot: Snapshot, start: TimeLike, end: TimeLike, execution_time: TimeLike, + snapshot: Snapshot, snapshots: t.Dict[str, Snapshot], - limit: t.Optional[int] = None, - deployability_index: t.Optional[DeployabilityIndex] = None, - batch_index: int = 0, + allow_destructive_snapshots: t.Set[str], + deployability_index: t.Optional[DeployabilityIndex], + batch_index: int, + target_table_exists: t.Optional[bool], **kwargs: t.Any, - ) -> DF | str | None: + ) -> t.Optional[str]: """Renders the snapshot's model and executes it. The return value depends on whether the limit was specified. Args: @@ -632,54 +682,206 @@ def _evaluate_snapshot( end: The end datetime to render. execution_time: The date/time time reference to use for execution time. snapshots: All upstream snapshots to use for expansion and mapping of physical locations. - limit: If limit is not None, the query will not be persisted but evaluated and returned as a dataframe. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. deployability_index: Determines snapshots that are deployable in the context of this evaluation. batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it + target_table_exists: Whether the target table exists. If None, the table will be checked for existence. kwargs: Additional kwargs to pass to the renderer. """ - if not snapshot.is_model or snapshot.is_seed: + if not snapshot.is_model: return None model = snapshot.model logger.info("Evaluating snapshot %s", snapshot.snapshot_id) - deployability_index = deployability_index or DeployabilityIndex.all_deployable() - table_name = ( - "" - if limit is not None - else snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) - ) - adapter = self.get_adapter(model.gateway) - evaluation_strategy = _evaluation_strategy(snapshot, adapter) - + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + is_snapshot_deployable = deployability_index.is_deployable(snapshot) + target_table_name = snapshot.table_name(is_deployable=is_snapshot_deployable) # https://github.com/TobikoData/sqlmesh/issues/2609 # If there are no existing intervals yet; only consider this a first insert for the first snapshot in the batch - is_first_insert = not _intervals(snapshot, deployability_index) and batch_index == 0 - - from sqlmesh.core.context import ExecutionContext - + if target_table_exists is None: + target_table_exists = adapter.table_exists(target_table_name) + is_first_insert = ( + not _intervals(snapshot, deployability_index) or not target_table_exists + ) and batch_index == 0 + + # Use the 'creating' stage if the table doesn't exist yet to preserve backwards compatibility with existing projects + # that depend on a separate physical table creation stage. + runtime_stage = RuntimeStage.EVALUATING if target_table_exists else RuntimeStage.CREATING common_render_kwargs = dict( start=start, end=end, execution_time=execution_time, snapshot=snapshot, - runtime_stage=RuntimeStage.EVALUATING, + runtime_stage=runtime_stage, **kwargs, ) - + create_render_kwargs = dict( + engine_adapter=adapter, + snapshots=snapshots, + deployability_index=deployability_index, + **common_render_kwargs, + ) + create_render_kwargs["runtime_stage"] = RuntimeStage.CREATING render_statements_kwargs = dict( engine_adapter=adapter, snapshots=snapshots, deployability_index=deployability_index, **common_render_kwargs, ) - rendered_physical_properties = snapshot.model.render_physical_properties( **render_statements_kwargs ) + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), + ): + adapter.execute(model.render_pre_statements(**render_statements_kwargs)) + + if not target_table_exists or (model.is_seed and not snapshot.intervals): + if self._can_clone(snapshot, deployability_index): + self._clone_snapshot_in_dev( + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + allow_destructive_snapshots=allow_destructive_snapshots, + ) + common_render_kwargs["runtime_stage"] = RuntimeStage.EVALUATING + elif model.annotated or model.is_seed or model.kind.is_scd_type_2: + self._execute_create( + snapshot=snapshot, + table_name=target_table_name, + is_table_deployable=is_snapshot_deployable, + deployability_index=deployability_index, + create_render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + dry_run=False, + run_pre_post_statements=False, + ) + common_render_kwargs["runtime_stage"] = RuntimeStage.EVALUATING + + wap_id: t.Optional[str] = None + if snapshot.is_materialized and ( + model.wap_supported or adapter.wap_supported(target_table_name) + ): + wap_id = random_id()[0:8] + logger.info("Using WAP ID '%s' for snapshot %s", wap_id, snapshot.snapshot_id) + target_table_name = adapter.wap_prepare(target_table_name, wap_id) + + self._render_and_insert_snapshot( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + snapshots=snapshots, + render_kwargs=common_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + deployability_index=deployability_index, + is_first_insert=is_first_insert, + batch_index=batch_index, + ) + + adapter.execute(model.render_post_statements(**render_statements_kwargs)) + + return wap_id + + def create_snapshot( + self, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + allow_destructive_snapshots: t.Set[str], + on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, + ) -> None: + """Creates a physical table for the given snapshot. + + Args: + snapshot: Snapshot to create. + snapshots: All upstream snapshots to use for expansion and mapping of physical locations. + deployability_index: Determines snapshots that are deployable in the context of this creation. + on_complete: A callback to call on each successfully created database object. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + """ + if not snapshot.is_model: + return + + logger.info("Creating a physical table for snapshot %s", snapshot.snapshot_id) + + adapter = self.get_adapter(snapshot.model.gateway) + create_render_kwargs: t.Dict[str, t.Any] = dict( + engine_adapter=adapter, + snapshots=snapshots, + runtime_stage=RuntimeStage.CREATING, + deployability_index=deployability_index, + ) + + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)), + ): + rendered_physical_properties = snapshot.model.render_physical_properties( + **create_render_kwargs + ) + + if self._can_clone(snapshot, deployability_index): + self._clone_snapshot_in_dev( + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + allow_destructive_snapshots=allow_destructive_snapshots, + ) + else: + is_table_deployable = deployability_index.is_deployable(snapshot) + self._execute_create( + snapshot=snapshot, + table_name=snapshot.table_name(is_deployable=is_table_deployable), + is_table_deployable=is_table_deployable, + deployability_index=deployability_index, + create_render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + dry_run=True, + ) + + if on_complete is not None: + on_complete(snapshot) + + def _render_and_insert_snapshot( + self, + start: TimeLike, + end: TimeLike, + execution_time: TimeLike, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], + deployability_index: DeployabilityIndex, + is_first_insert: bool, + batch_index: int, + ) -> None: + if not snapshot.is_model or snapshot.is_seed: + return + + logger.info("Inserting data for snapshot %s", snapshot.snapshot_id) + + model = snapshot.model + table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) + adapter = self.get_adapter(model.gateway) + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + + queries_or_dfs = self._render_snapshot_for_evaluation( + snapshot, + snapshots, + deployability_index, + render_kwargs, + ) + def apply(query_or_df: QueryOrDF, index: int = 0) -> None: if index > 0: evaluation_strategy.append( @@ -694,7 +896,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: end=end, execution_time=execution_time, physical_properties=rendered_physical_properties, - render_kwargs=render_statements_kwargs, + render_kwargs=render_kwargs, ) else: logger.info( @@ -716,203 +918,101 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: end=end, execution_time=execution_time, physical_properties=rendered_physical_properties, - render_kwargs=render_statements_kwargs, + render_kwargs=render_kwargs, ) - with ( - adapter.transaction(), - adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), + # DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. As a result, + # if the engine supports INSERT OVERWRITE or REPLACE WHERE and the snapshot is incremental by time range, we risk + # having a partial result since each dataframe write can re-truncate partitions. To avoid this, we + # union all the dataframes together before writing. For pandas this could result in OOM and a potential + # workaround for that would be to serialize pandas to disk and then read it back with Spark. + # Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes + # and not SQL expressions. + if ( + adapter.INSERT_OVERWRITE_STRATEGY + in ( + InsertOverwriteStrategy.INSERT_OVERWRITE, + InsertOverwriteStrategy.REPLACE_WHERE, + ) + and snapshot.is_incremental_by_time_range ): - wap_id: t.Optional[str] = None - if ( - table_name - and snapshot.is_materialized - and (model.wap_supported or adapter.wap_supported(table_name)) - ): - wap_id = random_id()[0:8] - logger.info("Using WAP ID '%s' for snapshot %s", wap_id, snapshot.snapshot_id) - table_name = adapter.wap_prepare(table_name, wap_id) - - if limit is None: - adapter.execute(model.render_pre_statements(**render_statements_kwargs)) - - queries_or_dfs = model.render( - context=ExecutionContext( - adapter, - snapshots, - deployability_index, - default_dialect=model.dialect, - default_catalog=model.default_catalog, - ), - **common_render_kwargs, + import pandas as pd + + query_or_df = reduce( + lambda a, b: ( + pd.concat([a, b], ignore_index=True) # type: ignore + if isinstance(a, pd.DataFrame) + else a.union_all(b) # type: ignore + ), # type: ignore + queries_or_dfs, ) + apply(query_or_df, index=0) + else: + for index, query_or_df in enumerate(queries_or_dfs): + apply(query_or_df, index) - if limit is not None: - import pandas as pd - - query_or_df = next(queries_or_dfs) - if isinstance(query_or_df, pd.DataFrame): - return query_or_df.head(limit) - if not isinstance(query_or_df, exp.Expression): - # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark / bigframe dataframe, - # so we use `limit` instead of `head` to get back a dataframe instead of List[Row] - # https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head - return query_or_df.limit(limit) - - assert isinstance(query_or_df, exp.Query) - - existing_limit = query_or_df.args.get("limit") - if existing_limit: - limit = min(limit, execute(exp.select(existing_limit.expression)).rows[0][0]) - assert limit is not None - - return adapter._fetch_native_df(query_or_df.limit(limit)) - - # DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. As a result, - # if the engine supports INSERT OVERWRITE or REPLACE WHERE and the snapshot is incremental by time range, we risk - # having a partial result since each dataframe write can re-truncate partitions. To avoid this, we - # union all the dataframes together before writing. For pandas this could result in OOM and a potential - # workaround for that would be to serialize pandas to disk and then read it back with Spark. - # Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes - # and not SQL expressions. - if ( - adapter.INSERT_OVERWRITE_STRATEGY - in ( - InsertOverwriteStrategy.INSERT_OVERWRITE, - InsertOverwriteStrategy.REPLACE_WHERE, - ) - and snapshot.is_incremental_by_time_range - ): - import pandas as pd - - query_or_df = reduce( - lambda a, b: ( - pd.concat([a, b], ignore_index=True) # type: ignore - if isinstance(a, pd.DataFrame) - else a.union_all(b) # type: ignore - ), # type: ignore - queries_or_dfs, - ) - apply(query_or_df, index=0) - else: - for index, query_or_df in enumerate(queries_or_dfs): - apply(query_or_df, index) + def _render_snapshot_for_evaluation( + self, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + render_kwargs: t.Dict[str, t.Any], + ) -> t.Iterator[QueryOrDF]: + from sqlmesh.core.context import ExecutionContext - if limit is None: - adapter.execute(model.render_post_statements(**render_statements_kwargs)) + model = snapshot.model + adapter = self.get_adapter(model.gateway) - return wap_id + return model.render( + context=ExecutionContext( + adapter, + snapshots, + deployability_index, + default_dialect=model.dialect, + default_catalog=model.default_catalog, + ), + **render_kwargs, + ) - def _create_snapshot( + def _clone_snapshot_in_dev( self, snapshot: Snapshot, - snapshots: t.Dict[SnapshotId, Snapshot], - deployability_flags: t.List[bool], - deployability_index: t.Optional[DeployabilityIndex], - on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], allow_destructive_snapshots: t.Set[str], ) -> None: - if not snapshot.is_model: - return - - deployability_index = deployability_index or DeployabilityIndex.all_deployable() - adapter = self.get_adapter(snapshot.model.gateway) - create_render_kwargs: t.Dict[str, t.Any] = dict( - engine_adapter=adapter, - snapshots=parent_snapshots_by_name(snapshot, snapshots), - runtime_stage=RuntimeStage.CREATING, - deployability_index=deployability_index, - ) - - with ( - adapter.transaction(), - adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)), - ): - rendered_physical_properties = snapshot.model.render_physical_properties( - **create_render_kwargs - ) - - if ( - snapshot.is_forward_only - and snapshot.is_materialized - and snapshot.previous_versions - and adapter.SUPPORTS_CLONING - # managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work - and not snapshot.is_managed - # If the deployable table is missing we can't clone it - and True not in deployability_flags - ): - target_table_name = snapshot.table_name(is_deployable=False) - tmp_table_name = f"{target_table_name}__schema_migration_source" - source_table_name = snapshot.table_name() - - logger.info(f"Cloning table '{source_table_name}' into '{target_table_name}'") - self._execute_create( - snapshot=snapshot, - table_name=tmp_table_name, - is_table_deployable=False, - deployability_index=deployability_index, - create_render_kwargs=create_render_kwargs, - rendered_physical_properties=rendered_physical_properties, - dry_run=True, - ) + target_table_name = snapshot.table_name(is_deployable=False) + source_table_name = snapshot.table_name() - try: - adapter.clone_table( - target_table_name, - snapshot.table_name(), - replace=True, - rendered_physical_properties=rendered_physical_properties, - ) - alter_expressions = adapter.get_alter_expressions( - target_table_name, - tmp_table_name, - ignore_destructive=snapshot.model.on_destructive_change.is_ignore, - ) - _check_destructive_schema_change( - snapshot, alter_expressions, allow_destructive_snapshots - ) - adapter.alter_table(alter_expressions) - except Exception: - adapter.drop_table(target_table_name) - raise - finally: - adapter.drop_table(tmp_table_name) - else: - dry_run = len(deployability_flags) == 1 - for is_table_deployable in deployability_flags: - if ( - is_table_deployable - and snapshot.model.forward_only - and not deployability_index.is_representative(snapshot) - ): - logger.info( - "Skipping creation of the deployable table '%s' for the forward-only model %s. " - "The table will be created when the snapshot is deployed to production", - snapshot.table_name(is_deployable=is_table_deployable), - snapshot.snapshot_id, - ) - continue - - self._execute_create( - snapshot=snapshot, - table_name=snapshot.table_name(is_deployable=is_table_deployable), - is_table_deployable=is_table_deployable, - deployability_index=deployability_index, - create_render_kwargs=create_render_kwargs, - rendered_physical_properties=rendered_physical_properties, - dry_run=dry_run, - ) - - if on_complete is not None: - on_complete(snapshot) + try: + logger.info(f"Cloning table '{source_table_name}' into '{target_table_name}'") + adapter.clone_table( + target_table_name, + snapshot.table_name(), + replace=True, + rendered_physical_properties=rendered_physical_properties, + ) + self._migrate_target_table( + target_table_name=target_table_name, + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=render_kwargs, + rendered_physical_properties=rendered_physical_properties, + allow_destructive_snapshots=allow_destructive_snapshots, + ) + except Exception: + adapter.drop_table(target_table_name) + raise def _migrate_snapshot( self, snapshot: Snapshot, - snapshots: t.Dict[SnapshotId, Snapshot], + snapshots: t.Dict[str, Snapshot], allow_destructive_snapshots: t.Set[str], adapter: EngineAdapter, deployability_index: DeployabilityIndex, @@ -923,7 +1023,7 @@ def _migrate_snapshot( deployability_index = DeployabilityIndex.all_deployable() render_kwargs: t.Dict[str, t.Any] = dict( engine_adapter=adapter, - snapshots=parent_snapshots_by_name(snapshot, snapshots), + snapshots=snapshots, runtime_stage=RuntimeStage.CREATING, deployability_index=deployability_index, ) @@ -941,39 +1041,63 @@ def _migrate_snapshot( table_exists = False if table_exists: - evaluation_strategy = _evaluation_strategy(snapshot, adapter) - tmp_table_name = snapshot.table_name(is_deployable=False) - logger.info( - "Migrating table schema '%s' to match '%s'", - target_table_name, - tmp_table_name, - ) - evaluation_strategy.migrate( + self._migrate_target_table( target_table_name=target_table_name, - source_table_name=tmp_table_name, snapshot=snapshot, - snapshots=parent_snapshots_by_name(snapshot, snapshots), - allow_destructive_snapshots=allow_destructive_snapshots, - ignore_destructive=snapshot.model.on_destructive_change.is_ignore, - ) - else: - logger.info( - "Creating table '%s' for the snapshot of the forward-only model %s", - target_table_name, - snapshot.snapshot_id, - ) - self._execute_create( - snapshot=snapshot, - table_name=target_table_name, - is_table_deployable=True, + snapshots=snapshots, deployability_index=deployability_index, - create_render_kwargs=render_kwargs, + render_kwargs=render_kwargs, rendered_physical_properties=snapshot.model.render_physical_properties( **render_kwargs ), - dry_run=False, + allow_destructive_snapshots=allow_destructive_snapshots, + run_pre_post_statements=True, ) + def _migrate_target_table( + self, + target_table_name: str, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], + allow_destructive_snapshots: t.Set[str], + run_pre_post_statements: bool = False, + ) -> None: + adapter = self.get_adapter(snapshot.model.gateway) + + tmp_table_name = f"{target_table_name}_schema_tmp" + if snapshot.is_materialized: + self._execute_create( + snapshot=snapshot, + table_name=tmp_table_name, + is_table_deployable=False, + deployability_index=deployability_index, + create_render_kwargs=render_kwargs, + rendered_physical_properties=rendered_physical_properties, + dry_run=False, + run_pre_post_statements=run_pre_post_statements, + ) + try: + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + logger.info( + "Migrating table schema from '%s' to '%s'", + tmp_table_name, + target_table_name, + ) + evaluation_strategy.migrate( + target_table_name=target_table_name, + source_table_name=tmp_table_name, + snapshot=snapshot, + snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots, + ignore_destructive=snapshot.model.on_destructive_change.is_ignore, + ) + finally: + if snapshot.is_materialized: + adapter.drop_table(tmp_table_name) + def _promote_snapshot( self, snapshot: Snapshot, @@ -1182,19 +1306,30 @@ def _create_catalogs( def _create_schemas( self, - tables: t.Iterable[t.Union[exp.Table, str]], - gateway: t.Optional[str] = None, + gateway_table_pairs: t.Iterable[t.Tuple[t.Optional[str], t.Union[exp.Table, str]]], ) -> None: - table_exprs = [exp.to_table(t) for t in tables] - unique_schemas = {(t.args["db"], t.args.get("catalog")) for t in table_exprs if t and t.db} - # Create schemas sequentially, since some engines (eg. Postgres) may not support concurrent creation - # of schemas with the same name. - for schema_name, catalog in unique_schemas: + table_exprs = [(gateway, exp.to_table(t)) for gateway, t in gateway_table_pairs] + unique_schemas = { + (gateway, t.args["db"], t.args.get("catalog")) + for gateway, t in table_exprs + if t and t.db + } + + def _create_schema( + gateway: t.Optional[str], schema_name: str, catalog: t.Optional[str] + ) -> None: schema = schema_(schema_name, catalog) logger.info("Creating schema '%s'", schema) adapter = self.get_adapter(gateway) adapter.create_schema(schema) + with self.concurrent_context(): + concurrent_apply_to_values( + list(unique_schemas), + lambda item: _create_schema(item[0], item[1], item[2]), + self.ddl_concurrent_tasks, + ) + def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: """Returns the adapter for the specified gateway or the default adapter if none is provided.""" if gateway: @@ -1212,6 +1347,7 @@ def _execute_create( create_render_kwargs: t.Dict[str, t.Any], rendered_physical_properties: t.Dict[str, exp.Expression], dry_run: bool, + run_pre_post_statements: bool = True, ) -> None: adapter = self.get_adapter(snapshot.model.gateway) evaluation_strategy = _evaluation_strategy(snapshot, adapter) @@ -1224,7 +1360,8 @@ def _execute_create( **create_render_kwargs, "table_mapping": {snapshot.name: table_name}, } - adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs)) + if run_pre_post_statements: + adapter.execute(snapshot.model.render_pre_statements(**create_render_kwargs)) evaluation_strategy.create( table_name=table_name, model=snapshot.model, @@ -1235,16 +1372,20 @@ def _execute_create( dry_run=dry_run, physical_properties=rendered_physical_properties, ) - adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs)) + if run_pre_post_statements: + adapter.execute(snapshot.model.render_post_statements(**create_render_kwargs)) - def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator: - return SnapshotEvaluator( - { - gateway: adapter.with_settings(correlation_id=correlation_id) - for gateway, adapter in self.adapters.items() - }, - self.ddl_concurrent_tasks, - self.selected_gateway, + def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool: + adapter = self.get_adapter(snapshot.model.gateway) + return ( + snapshot.is_forward_only + and snapshot.is_materialized + and bool(snapshot.previous_versions) + and adapter.SUPPORTS_CLONING + # managed models cannot have their schema mutated because theyre based on queries, so clone + alter wont work + and not snapshot.is_managed + # If the deployable table is missing we can't clone it + and not deployability_index.is_deployable(snapshot) ) @@ -1607,10 +1748,13 @@ def _replace_query_for_model( columns_to_types = model.columns_to_types_or_raise source_columns: t.Optional[t.List[str]] = list(columns_to_types) else: - # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. - columns_to_types, source_columns = self._get_target_and_source_columns( - model, name, render_kwargs, force_get_columns_from_target=True - ) + try: + # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. + columns_to_types, source_columns = self._get_target_and_source_columns( + model, name, render_kwargs, force_get_columns_from_target=True + ) + except Exception: + columns_to_types, source_columns = None, None self.adapter.replace_query( name, @@ -2087,11 +2231,7 @@ def insert( render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - deployability_index = ( - kwargs.get("deployability_index") or DeployabilityIndex.all_deployable() - ) snapshot = kwargs["snapshot"] - snapshots = kwargs["snapshots"] if ( not snapshot.is_materialized_view @@ -2131,17 +2271,6 @@ def create( render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - is_snapshot_representative: bool = kwargs["is_snapshot_representative"] - if not is_snapshot_representative and is_table_deployable: - # If the snapshot is not representative, the query may contain references to non-deployable tables or views. - # This may happen if there was a forward-only change upstream which now requires the view query to point at dev preview tables. - # Therefore, we postpone the creation of the deployable view until the snapshot is deployed to production. - logger.info( - "Skipping creation of the deployable view '%s' for the non-representative snapshot", - table_name, - ) - return - if self.adapter.table_exists(table_name): # Make sure we don't recreate the view to prevent deletion of downstream views in engines with no late # binding support (because of DROP CASCADE). diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index b4e8caf0bc..4bfe78cca0 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -57,9 +57,9 @@ def sqlmesh_config( if model_defaults.dialect is None: model_defaults.dialect = profile.target.dialect - target_to_sqlmesh_args = {} - if register_comments is not None: - target_to_sqlmesh_args["register_comments"] = register_comments + target_to_sqlmesh_args = { + "register_comments": register_comments or False, + } loader = kwargs.pop("loader", DbtLoader) if not issubclass(loader, DbtLoader): diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 45accccaa8..f283680cfb 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -139,10 +139,6 @@ def assert_new_env(result, new_env="prod", from_env="prod", initialize=True) -> ) in result.output -def assert_physical_layer_updated(result) -> None: - assert "Physical layer updated" in result.output - - def assert_model_batches_executed(result) -> None: assert "Model batches executed" in result.output @@ -152,7 +148,6 @@ def assert_virtual_layer_updated(result) -> None: def assert_backfill_success(result) -> None: - assert_physical_layer_updated(result) assert_model_batches_executed(result) assert_virtual_layer_updated(result) diff --git a/tests/cli/test_integration_cli.py b/tests/cli/test_integration_cli.py index 9b39b948f2..5d000b9d8b 100644 --- a/tests/cli/test_integration_cli.py +++ b/tests/cli/test_integration_cli.py @@ -141,7 +141,6 @@ def do_something(evaluator): result = invoke_cli(["plan", "--no-prompts", "--auto-apply", "--skip-tests"]) assert result.returncode == 0 - assert "Physical layer updated" in result.stdout assert "Virtual layer updated" in result.stdout # render the query to ensure our macro is being invoked @@ -175,7 +174,6 @@ def do_something(evaluator): ] ) assert result.returncode == 0 - assert "Physical layer updated" in result.stdout assert "Virtual layer updated" in result.stdout log_file_contents = last_log_file_contents() @@ -236,7 +234,6 @@ def do_something(evaluator): result = invoke_cli(["plan", "--no-prompts", "--auto-apply", "--skip-tests"]) assert result.returncode == 0 - assert "Physical layer updated" in result.stdout assert "Virtual layer updated" in result.stdout # clear cache to ensure we are forced to reload everything @@ -266,7 +263,6 @@ def do_something(evaluator): ) assert result.returncode == 0 assert "Apply - Backfill Tables [y/n]:" in result.stdout - assert "Physical layer updated" not in result.stdout # the invalid snapshot in state should not prevent a plan if --select-model is used on it (since the local version can be rendered) result = invoke_cli( @@ -343,7 +339,6 @@ def test_model_selector_tags_picks_up_both_remote_and_local( result = invoke_cli(["plan", "--no-prompts", "--auto-apply", "--skip-tests"]) assert result.returncode == 0 - assert "Physical layer updated" in result.stdout assert "Virtual layer updated" in result.stdout # add a new model locally with tag:a diff --git a/tests/conftest.py b/tests/conftest.py index ad09deff6f..c51d3e5912 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -271,13 +271,16 @@ def push_plan(context: Context, plan: Plan) -> None: context.default_catalog, ) deployability_index = DeployabilityIndex.create(context.snapshots.values()) - evaluatable_plan = plan.to_evaluatable() + evaluatable_plan = plan.to_evaluatable().copy(update={"skip_backfill": True}) stages = plan_stages.build_plan_stages( evaluatable_plan, context.state_sync, context.default_catalog ) for stage in stages: if isinstance(stage, plan_stages.CreateSnapshotRecordsStage): plan_evaluator.visit_create_snapshot_records_stage(stage, evaluatable_plan) + elif isinstance(stage, plan_stages.PhysicalLayerSchemaCreationStage): + stage.deployability_index = deployability_index + plan_evaluator.visit_physical_layer_schema_creation_stage(stage, evaluatable_plan) elif isinstance(stage, plan_stages.PhysicalLayerUpdateStage): stage.deployability_index = deployability_index plan_evaluator.visit_physical_layer_update_stage(stage, evaluatable_plan) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py index 1b7d54a2d9..6ababb7c71 100644 --- a/tests/core/engine_adapter/integration/test_integration.py +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -1914,6 +1914,11 @@ def test_sushi(ctx: TestContext, tmp_path_factory: pytest.TempPathFactory): ], personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()], ) + config.before_all = [ + f"CREATE SCHEMA IF NOT EXISTS {raw_test_schema}", + f"DROP VIEW IF EXISTS {raw_test_schema}.demographics", + f"CREATE VIEW {raw_test_schema}.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)", + ] # To enable parallelism in integration tests config.gateways = {ctx.gateway: config.gateways[ctx.gateway]} @@ -2132,6 +2137,8 @@ def validate_comments( } for model_name, comment in comments.items(): + if not model_name in layer_models: + continue layer_table_name = layer_models[model_name]["table_name"] table_kind = "VIEW" if layer_models[model_name]["is_view"] else "BASE TABLE" diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py index caa7843726..5923afa217 100644 --- a/tests/core/engine_adapter/test_mssql.py +++ b/tests/core/engine_adapter/test_mssql.py @@ -249,6 +249,7 @@ def test_incremental_by_time_datetimeoffset_precision( end="2020-01-02", execution_time="2020-01-02", snapshots={}, + target_table_exists=True, ) assert adapter.cursor.execute.call_args_list[0][0][0] == ( diff --git a/tests/core/test_context.py b/tests/core/test_context.py index a94ba74a20..3b7c5bd51d 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -689,9 +689,7 @@ def test_plan_apply_populates_cache(copy_to_temp_path, mocker): config_content = f.read() # Add cache_dir to the test_config definition - config_content = config_content.replace( - 'test_config = Config(\n gateways={"in_memory": GatewayConfig(connection=DuckDBConnectionConfig())},\n default_gateway="in_memory",\n plan=PlanConfig(\n auto_categorize_changes=CategorizerConfig(\n sql=AutoCategorizationMode.SEMI, python=AutoCategorizationMode.OFF\n )\n ),\n model_defaults=model_defaults,\n)', - f"""test_config = Config( + config_content += f"""test_config_cache_dir = Config( gateways={{"in_memory": GatewayConfig(connection=DuckDBConnectionConfig())}}, default_gateway="in_memory", plan=PlanConfig( @@ -701,14 +699,14 @@ def test_plan_apply_populates_cache(copy_to_temp_path, mocker): ), model_defaults=model_defaults, cache_dir="{custom_cache_dir.as_posix()}", -)""", - ) + before_all=before_all, +)""" with open(config_py_path, "w") as f: f.write(config_content) # Create context with the test config - context = Context(paths=sushi_path, config="test_config") + context = Context(paths=sushi_path, config="test_config_cache_dir") custom_cache_dir = context.cache_dir assert "custom_cache" in str(custom_cache_dir) assert (custom_cache_dir / "optimized_query").exists() @@ -733,7 +731,7 @@ def test_plan_apply_populates_cache(copy_to_temp_path, mocker): # New context should load same models and create the cache for optimized_query and model_definition initial_model_count = len(context.models) - context2 = Context(paths=context.path, config="test_config") + context2 = Context(paths=context.path, config="test_config_cache_dir") cached_model_count = len(context2.models) assert initial_model_count == cached_model_count > 0 @@ -1778,14 +1776,14 @@ def test_plan_environment_statements(tmp_path: pathlib.Path): ); @IF( - @runtime_stage = 'evaluating', + @runtime_stage IN ('evaluating', 'creating'), SET VARIABLE stats_model_start = now() ); SELECT 1 AS cola; @IF( - @runtime_stage = 'evaluating', + @runtime_stage IN ('evaluating', 'creating'), INSERT INTO analytic_stats (physical_table, evaluation_start, evaluation_end, evaluation_time) VALUES (@resolve_template('@{schema_name}.@{table_name}'), getvariable('stats_model_start'), now(), now() - getvariable('stats_model_start')) ); @@ -1851,11 +1849,11 @@ def access_adapter(evaluator): assert ( model.pre_statements[0].sql() - == "@IF(@runtime_stage = 'evaluating', SET VARIABLE stats_model_start = NOW())" + == "@IF(@runtime_stage IN ('evaluating', 'creating'), SET VARIABLE stats_model_start = NOW())" ) assert ( model.post_statements[0].sql() - == "@IF(@runtime_stage = 'evaluating', INSERT INTO analytic_stats (physical_table, evaluation_start, evaluation_end, evaluation_time) VALUES (@resolve_template('@{schema_name}.@{table_name}'), GETVARIABLE('stats_model_start'), NOW(), NOW() - GETVARIABLE('stats_model_start')))" + == "@IF(@runtime_stage IN ('evaluating', 'creating'), INSERT INTO analytic_stats (physical_table, evaluation_start, evaluation_end, evaluation_time) VALUES (@resolve_template('@{schema_name}.@{table_name}'), GETVARIABLE('stats_model_start'), NOW(), NOW() - GETVARIABLE('stats_model_start')))" ) stats_table = context.fetchdf("select * from memory.analytic_stats").to_dict() diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 72d8964a71..fc129424f4 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -37,7 +37,7 @@ DuckDBConnectionConfig, TableNamingConvention, ) -from sqlmesh.core.config.common import EnvironmentSuffixTarget +from sqlmesh.core.config.common import EnvironmentSuffixTarget, VirtualEnvironmentMode from sqlmesh.core.console import Console, get_console from sqlmesh.core.context import Context from sqlmesh.core.config.categorizer import CategorizerConfig @@ -488,12 +488,6 @@ def test_full_history_restatement_model_regular_plan_preview_enabled( waiter_as_customer_snapshot = context.get_snapshot( "sushi.waiter_as_customer_by_day", raise_if_missing=True ) - count_customers_active_snapshot = context.get_snapshot( - "sushi.count_customers_active", raise_if_missing=True - ) - count_customers_inactive_snapshot = context.get_snapshot( - "sushi.count_customers_inactive", raise_if_missing=True - ) plan = context.plan_builder("dev", skip_tests=True, enable_preview=True).build() @@ -959,8 +953,9 @@ def test_new_forward_only_model(init_and_plan_context: t.Callable): @time_machine.travel("2023-01-08 15:00:00 UTC") def test_plan_set_choice_is_reflected_in_missing_intervals(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) + context, _ = init_and_plan_context("examples/sushi") + context.upsert_model(context.get_model("sushi.top_waiters").copy(update={"kind": FullKind()})) + context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True) model_name = "sushi.waiter_revenue_by_day" @@ -1461,6 +1456,18 @@ def test_indirect_non_breaking_downstream_of_forward_only(init_and_plan_context: plan = context.plan_builder("prod", skip_tests=True).build() assert plan.start == to_timestamp("2023-01-01") assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiter_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), SnapshotIntervals( snapshot_id=non_breaking_snapshot.snapshot_id, intervals=[ @@ -1485,8 +1492,9 @@ def test_indirect_non_breaking_downstream_of_forward_only(init_and_plan_context: @time_machine.travel("2023-01-08 15:00:00 UTC") def test_breaking_only_impacts_immediate_children(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) + context, _ = init_and_plan_context("examples/sushi") + context.upsert_model(context.get_model("sushi.top_waiters").copy(update={"kind": FullKind()})) + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) breaking_model = context.get_model("sushi.orders") breaking_model = breaking_model.copy(update={"stamp": "force new version"}) @@ -2206,9 +2214,7 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( context.upsert_model(add_projection_to_model(t.cast(SqlModel, forward_only_model))) forward_only_model_snapshot_id = context.get_snapshot(forward_only_model_name).snapshot_id full_downstream_model_snapshot_id = context.get_snapshot(full_downstream_model_name).snapshot_id - full_downstream_model_2_snapshot_id = context.get_snapshot( - view_downstream_model_name - ).snapshot_id + view_downstream_model_snapshot_id = context.get_snapshot(view_downstream_model_name).snapshot_id dev_plan = context.plan("dev", auto_apply=True, no_prompts=True, enable_preview=False) assert ( dev_plan.snapshots[forward_only_model_snapshot_id].change_category @@ -2219,7 +2225,7 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( == SnapshotChangeCategory.INDIRECT_NON_BREAKING ) assert ( - dev_plan.snapshots[full_downstream_model_2_snapshot_id].change_category + dev_plan.snapshots[view_downstream_model_snapshot_id].change_category == SnapshotChangeCategory.INDIRECT_NON_BREAKING ) assert not dev_plan.missing_intervals @@ -2238,9 +2244,7 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( new_full_downstream_model = load_sql_based_model(new_full_downstream_model_expressions) context.upsert_model(new_full_downstream_model) full_downstream_model_snapshot_id = context.get_snapshot(full_downstream_model_name).snapshot_id - full_downstream_model_2_snapshot_id = context.get_snapshot( - view_downstream_model_name - ).snapshot_id + view_downstream_model_snapshot_id = context.get_snapshot(view_downstream_model_name).snapshot_id dev_plan = context.plan( "dev", categorizer_config=CategorizerConfig.all_full(), @@ -2253,12 +2257,12 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( == SnapshotChangeCategory.BREAKING ) assert ( - dev_plan.snapshots[full_downstream_model_2_snapshot_id].change_category + dev_plan.snapshots[view_downstream_model_snapshot_id].change_category == SnapshotChangeCategory.INDIRECT_BREAKING ) assert len(dev_plan.missing_intervals) == 2 assert dev_plan.missing_intervals[0].snapshot_id == full_downstream_model_snapshot_id - assert dev_plan.missing_intervals[1].snapshot_id == full_downstream_model_2_snapshot_id + assert dev_plan.missing_intervals[1].snapshot_id == view_downstream_model_snapshot_id # Check that the representative view hasn't been created yet. assert not context.engine_adapter.table_exists( @@ -2272,9 +2276,7 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( # Finally, make a non-breaking change to the full model in the same dev environment. context.upsert_model(add_projection_to_model(t.cast(SqlModel, new_full_downstream_model))) full_downstream_model_snapshot_id = context.get_snapshot(full_downstream_model_name).snapshot_id - full_downstream_model_2_snapshot_id = context.get_snapshot( - view_downstream_model_name - ).snapshot_id + view_downstream_model_snapshot_id = context.get_snapshot(view_downstream_model_name).snapshot_id dev_plan = context.plan( "dev", categorizer_config=CategorizerConfig.all_full(), @@ -2287,10 +2289,13 @@ def test_indirect_non_breaking_view_model_non_representative_snapshot( == SnapshotChangeCategory.NON_BREAKING ) assert ( - dev_plan.snapshots[full_downstream_model_2_snapshot_id].change_category + dev_plan.snapshots[view_downstream_model_snapshot_id].change_category == SnapshotChangeCategory.INDIRECT_NON_BREAKING ) + # Deploy changes to prod + context.plan("prod", auto_apply=True, no_prompts=True) + # Check that the representative view has been created. assert context.engine_adapter.table_exists( context.get_snapshot(view_downstream_model_name).table_name() @@ -2659,6 +2664,66 @@ def test_virtual_environment_mode_dev_only_model_kind_change(init_and_plan_conte assert data_objects[0].type == "table" +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_kind_change_incremental( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + + forward_only_model_name = "memory.sushi.test_forward_only_model" + forward_only_model_expressions = d.parse( + f""" + MODEL ( + name {forward_only_model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + ), + ); + + SELECT '2023-01-01' AS ds, 'value' AS value; + """ + ) + forward_only_model = load_sql_based_model(forward_only_model_expressions) + forward_only_model = forward_only_model.copy( + update={"virtual_environment_mode": VirtualEnvironmentMode.DEV_ONLY} + ) + context.upsert_model(forward_only_model) + + context.plan("prod", auto_apply=True, no_prompts=True) + + # Change to view + model = context.get_model(forward_only_model_name) + original_kind = model.kind + model = model.copy(update={"kind": ViewKind()}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"test_forward_only_model"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "view" + + model = model.copy(update={"kind": original_kind}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"test_forward_only_model"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "table" + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_virtual_environment_mode_dev_only_model_kind_change_with_follow_up_changes_in_dev( init_and_plan_context: t.Callable, @@ -3146,7 +3211,7 @@ def test_restatement_plan_clears_correct_intervals_across_environments(tmp_path: cron '@daily' ); - select account_id, name, date from test.external_table; + select 1 as account_id, date from test.external_table; """ with open(models_dir / "model1.sql", "w") as f: f.write(model1) @@ -3945,11 +4010,12 @@ def test_plan_snapshot_table_exists_for_promoted_snapshot(init_and_plan_context: "UPDATE sqlmesh._environments SET finalized_ts = NULL WHERE name = 'dev'" ) - model = context.get_model("sushi.customers") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - context.plan( - "dev", select_models=["sushi.customers"], auto_apply=True, no_prompts=True, skip_tests=True + "prod", + restate_models=["sushi.top_waiters"], + auto_apply=True, + no_prompts=True, + skip_tests=True, ) assert context.engine_adapter.table_exists(top_waiters_snapshot.table_name()) @@ -4526,7 +4592,7 @@ def test_plan_repairs_unrenderable_snapshot_state( plan = plan_builder.build() assert plan.directly_modified == {target_snapshot.snapshot_id} if not forward_only: - assert {i.snapshot_id for i in plan.missing_intervals} == {target_snapshot.snapshot_id} + assert target_snapshot.snapshot_id in {i.snapshot_id for i in plan.missing_intervals} plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) plan = plan_builder.build() @@ -5265,11 +5331,14 @@ def test_multi(mocker): assert context.fetchdf("select * from after_1").to_dict()["repo_1"][0] == "repo_1" assert context.fetchdf("select * from after_2").to_dict()["repo_2"][0] == "repo_2" + old_context = context context = Context( paths=["examples/multi/repo_1"], - state_sync=context.state_sync, + state_sync=old_context.state_sync, gateway="memory", ) + context._engine_adapter = old_context.engine_adapter + del context.engine_adapters model = context.get_model("bronze.a") assert model.project == "repo_1" @@ -5862,7 +5931,7 @@ def get_default_catalog_and_non_tables( ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) assert len(prod_views) == 16 assert len(dev_views) == 0 - assert len(user_default_tables) == 21 + assert len(user_default_tables) == 15 assert state_metadata.schemas == ["sqlmesh"] assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( { @@ -5881,7 +5950,7 @@ def get_default_catalog_and_non_tables( ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) assert len(prod_views) == 16 assert len(dev_views) == 16 - assert len(user_default_tables) == 21 + assert len(user_default_tables) == 16 assert len(non_default_tables) == 0 assert state_metadata.schemas == ["sqlmesh"] assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( @@ -5901,7 +5970,7 @@ def get_default_catalog_and_non_tables( ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) assert len(prod_views) == 16 assert len(dev_views) == 32 - assert len(user_default_tables) == 21 + assert len(user_default_tables) == 16 assert len(non_default_tables) == 0 assert state_metadata.schemas == ["sqlmesh"] assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( @@ -5922,7 +5991,7 @@ def get_default_catalog_and_non_tables( ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) assert len(prod_views) == 16 assert len(dev_views) == 16 - assert len(user_default_tables) == 21 + assert len(user_default_tables) == 16 assert len(non_default_tables) == 0 assert state_metadata.schemas == ["sqlmesh"] assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( @@ -6074,13 +6143,13 @@ def test_restatement_of_full_model_with_start(init_and_plan_context: t.Callable) @time_machine.travel("2023-01-08 15:00:00 UTC") def test_restatement_should_not_override_environment_statements(init_and_plan_context: t.Callable): context, _ = init_and_plan_context("examples/sushi") - context.config.before_all = ["SELECT 'test_before_all';"] + context.config.before_all = ["SELECT 'test_before_all';", *context.config.before_all] context.load() context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) prod_env_statements = context.state_reader.get_environment_statements(c.PROD) - assert prod_env_statements[0].before_all == ["SELECT 'test_before_all';"] + assert prod_env_statements[0].before_all[0] == "SELECT 'test_before_all';" context.plan( restate_models=["sushi.waiter_revenue_by_day"], @@ -6090,7 +6159,7 @@ def test_restatement_should_not_override_environment_statements(init_and_plan_co ) prod_env_statements = context.state_reader.get_environment_statements(c.PROD) - assert prod_env_statements[0].before_all == ["SELECT 'test_before_all';"] + assert prod_env_statements[0].before_all[0] == "SELECT 'test_before_all';" @time_machine.travel("2023-01-08 15:00:00 UTC") @@ -6138,7 +6207,7 @@ def test_plan_production_environment_statements(tmp_path: Path): ); @IF( - @runtime_stage = 'creating', + @runtime_stage IN ('evaluating', 'creating'), INSERT INTO schema_names_for_prod (physical_schema_name) VALUES (@resolve_template('@{schema_name}')) ); @@ -6893,17 +6962,7 @@ def plan_with_output(ctx: Context, environment: str): assert "New environment `dev` will be created from `prod`" in output.stdout assert "Differences from the `prod` environment" in output.stdout - assert ( - """MODEL ( - name test.a, -+ owner test, - kind FULL - ) - SELECT -- 5 AS col -+ 10 AS col""" - in output.stdout - ) + assert "Directly Modified: test__dev.a" in output.stdout # Case 6: Ensure that target environment and create_from environment are not the same output = plan_with_output(ctx, "prod") @@ -7299,11 +7358,7 @@ def test_engine_adapters_multi_repo_all_gateways_gathered(copy_to_temp_path): def test_physical_table_naming_strategy_table_only(copy_to_temp_path: t.Callable): sushi_context = Context( paths=copy_to_temp_path("examples/sushi"), - config=Config( - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - default_connection=DuckDBConnectionConfig(), - physical_table_naming_convention=TableNamingConvention.TABLE_ONLY, - ), + config="table_only_naming_config", ) assert sushi_context.config.physical_table_naming_convention == TableNamingConvention.TABLE_ONLY @@ -7334,11 +7389,7 @@ def test_physical_table_naming_strategy_table_only(copy_to_temp_path: t.Callable def test_physical_table_naming_strategy_hash_md5(copy_to_temp_path: t.Callable): sushi_context = Context( paths=copy_to_temp_path("examples/sushi"), - config=Config( - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - default_connection=DuckDBConnectionConfig(), - physical_table_naming_convention=TableNamingConvention.HASH_MD5, - ), + config="hash_md5_naming_config", ) assert sushi_context.config.physical_table_naming_convention == TableNamingConvention.HASH_MD5 diff --git a/tests/core/test_plan_evaluator.py b/tests/core/test_plan_evaluator.py index a3735b08ed..575f5ae742 100644 --- a/tests/core/test_plan_evaluator.py +++ b/tests/core/test_plan_evaluator.py @@ -69,10 +69,12 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot): stages = plan_stages.build_plan_stages( evaluatable_plan, sushi_context.state_sync, sushi_context.default_catalog ) - assert isinstance(stages[0], plan_stages.CreateSnapshotRecordsStage) - evaluator.visit_create_snapshot_records_stage(stages[0], evaluatable_plan) - assert isinstance(stages[1], plan_stages.PhysicalLayerUpdateStage) - evaluator.visit_physical_layer_update_stage(stages[1], evaluatable_plan) + assert isinstance(stages[1], plan_stages.CreateSnapshotRecordsStage) + evaluator.visit_create_snapshot_records_stage(stages[1], evaluatable_plan) + assert isinstance(stages[2], plan_stages.PhysicalLayerSchemaCreationStage) + evaluator.visit_physical_layer_schema_creation_stage(stages[2], evaluatable_plan) + assert isinstance(stages[3], plan_stages.BackfillStage) + evaluator.visit_backfill_stage(stages[3], evaluatable_plan) assert ( len(sushi_context.state_sync.get_snapshots([new_model_snapshot, new_view_model_snapshot])) diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py index aedf50e26f..7b172caf6a 100644 --- a/tests/core/test_plan_stages.py +++ b/tests/core/test_plan_stages.py @@ -12,6 +12,7 @@ AfterAllStage, AuditOnlyRunStage, PhysicalLayerUpdateStage, + PhysicalLayerSchemaCreationStage, CreateSnapshotRecordsStage, BeforeAllStage, BackfillStage, @@ -134,18 +135,14 @@ def test_build_plan_stages_basic( snapshot_a.snapshot_id, snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[1] - assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { snapshot_a.snapshot_id, snapshot_b.snapshot_id, } - assert {s.snapshot_id for s in physical_stage.snapshots_with_missing_intervals} == { - snapshot_a.snapshot_id, - snapshot_b.snapshot_id, - } assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() # Verify BackfillStage @@ -252,9 +249,9 @@ def test_build_plan_stages_with_before_all_and_after_all( snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[2] - assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { snapshot_a.snapshot_id, @@ -356,13 +353,12 @@ def test_build_plan_stages_select_models( snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[1] - assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) assert len(physical_stage.snapshots) == 1 assert {s.snapshot_id for s in physical_stage.snapshots} == {snapshot_a.snapshot_id} assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() - assert physical_stage.snapshots_with_missing_intervals == {snapshot_a.snapshot_id} # Verify BackfillStage backfill_stage = stages[2] @@ -446,7 +442,7 @@ def test_build_plan_stages_basic_no_backfill( stages = build_plan_stages(plan, state_reader, None) # Verify stages - assert len(stages) == 7 + assert len(stages) == 8 # Verify CreateSnapshotRecordsStage create_snapshot_records_stage = stages[0] @@ -456,8 +452,17 @@ def test_build_plan_stages_basic_no_backfill( snapshot_a.snapshot_id, snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerUpdateStage + physical_stage = stages[2] assert isinstance(physical_stage, PhysicalLayerUpdateStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { @@ -466,28 +471,28 @@ def test_build_plan_stages_basic_no_backfill( } # Verify BackfillStage - backfill_stage = stages[2] + backfill_stage = stages[3] assert isinstance(backfill_stage, BackfillStage) assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() assert backfill_stage.snapshot_to_intervals == {} # Verify EnvironmentRecordUpdateStage - assert isinstance(stages[3], EnvironmentRecordUpdateStage) - assert stages[3].no_gaps_snapshot_names == {snapshot_a.name, snapshot_b.name} + assert isinstance(stages[4], EnvironmentRecordUpdateStage) + assert stages[4].no_gaps_snapshot_names == {snapshot_a.name, snapshot_b.name} # Verify UnpauseStage - assert isinstance(stages[4], UnpauseStage) - assert {s.name for s in stages[4].promoted_snapshots} == {snapshot_a.name, snapshot_b.name} + assert isinstance(stages[5], UnpauseStage) + assert {s.name for s in stages[5].promoted_snapshots} == {snapshot_a.name, snapshot_b.name} # Verify VirtualLayerUpdateStage - virtual_stage = stages[5] + virtual_stage = stages[6] assert isinstance(virtual_stage, VirtualLayerUpdateStage) assert len(virtual_stage.promoted_snapshots) == 2 assert len(virtual_stage.demoted_snapshots) == 0 assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} # Verify FinalizeEnvironmentStage - assert isinstance(stages[6], FinalizeEnvironmentStage) + assert isinstance(stages[7], FinalizeEnvironmentStage) def test_build_plan_stages_restatement( @@ -558,9 +563,9 @@ def test_build_plan_stages_restatement( # Verify stages assert len(stages) == 5 - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[0] - assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { snapshot_a.snapshot_id, @@ -679,17 +684,15 @@ def test_build_plan_stages_forward_only( new_snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[1] - assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { new_snapshot_a.snapshot_id, new_snapshot_b.snapshot_id, } - assert physical_stage.deployability_index == DeployabilityIndex.create( - [new_snapshot_a, new_snapshot_b] - ) + assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() # Verify EnvironmentRecordUpdateStage assert isinstance(stages[2], EnvironmentRecordUpdateStage) @@ -808,9 +811,9 @@ def test_build_plan_stages_forward_only_dev( new_snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[1] - assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { new_snapshot_a.snapshot_id, @@ -921,7 +924,7 @@ def _get_snapshots(snapshot_ids: t.List[SnapshotId]) -> t.Dict[SnapshotId, Snaps stages = build_plan_stages(plan, state_reader, None) # Verify stages - assert len(stages) == 7 + assert len(stages) == 8 # Verify CreateSnapshotRecordsStage create_snapshot_records_stage = stages[0] @@ -932,8 +935,20 @@ def _get_snapshots(snapshot_ids: t.List[SnapshotId]) -> t.Dict[SnapshotId, Snaps new_snapshot_b.snapshot_id, } - # Verify PhysicalLayerUpdateStage + # Verify PhysicalLayerSchemaCreationStage physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.create( + [new_snapshot_a, new_snapshot_b] + ) + + # Verify PhysicalLayerUpdateStage + physical_stage = stages[2] assert isinstance(physical_stage, PhysicalLayerUpdateStage) assert len(physical_stage.snapshots) == 2 assert {s.snapshot_id for s in physical_stage.snapshots} == { @@ -945,28 +960,28 @@ def _get_snapshots(snapshot_ids: t.List[SnapshotId]) -> t.Dict[SnapshotId, Snaps ) # Verify AuditOnlyRunStage - audit_only_stage = stages[2] + audit_only_stage = stages[3] assert isinstance(audit_only_stage, AuditOnlyRunStage) assert len(audit_only_stage.snapshots) == 1 assert audit_only_stage.snapshots[0].snapshot_id == new_snapshot_a.snapshot_id # Verify BackfillStage - backfill_stage = stages[3] + backfill_stage = stages[4] assert isinstance(backfill_stage, BackfillStage) assert len(backfill_stage.snapshot_to_intervals) == 0 # Verify EnvironmentRecordUpdateStage - assert isinstance(stages[4], EnvironmentRecordUpdateStage) + assert isinstance(stages[5], EnvironmentRecordUpdateStage) # Verify VirtualLayerUpdateStage - virtual_stage = stages[5] + virtual_stage = stages[6] assert isinstance(virtual_stage, VirtualLayerUpdateStage) assert len(virtual_stage.promoted_snapshots) == 2 assert len(virtual_stage.demoted_snapshots) == 0 assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} # Verify FinalizeEnvironmentStage - assert isinstance(stages[6], FinalizeEnvironmentStage) + assert isinstance(stages[7], FinalizeEnvironmentStage) def test_build_plan_stages_forward_only_ensure_finalized_snapshots( @@ -1046,7 +1061,7 @@ def test_build_plan_stages_forward_only_ensure_finalized_snapshots( assert len(stages) == 8 assert isinstance(stages[0], CreateSnapshotRecordsStage) - assert isinstance(stages[1], PhysicalLayerUpdateStage) + assert isinstance(stages[1], PhysicalLayerSchemaCreationStage) assert isinstance(stages[2], EnvironmentRecordUpdateStage) assert isinstance(stages[3], MigrateSchemasStage) assert isinstance(stages[4], BackfillStage) @@ -1120,7 +1135,7 @@ def test_build_plan_stages_removed_model( # Verify stages assert len(stages) == 5 - assert isinstance(stages[0], PhysicalLayerUpdateStage) + assert isinstance(stages[0], PhysicalLayerSchemaCreationStage) assert isinstance(stages[1], BackfillStage) assert isinstance(stages[2], EnvironmentRecordUpdateStage) assert isinstance(stages[3], VirtualLayerUpdateStage) @@ -1202,7 +1217,7 @@ def test_build_plan_stages_environment_suffix_target_changed( # Verify stages assert len(stages) == 5 - assert isinstance(stages[0], PhysicalLayerUpdateStage) + assert isinstance(stages[0], PhysicalLayerSchemaCreationStage) assert isinstance(stages[1], BackfillStage) assert isinstance(stages[2], EnvironmentRecordUpdateStage) assert isinstance(stages[3], VirtualLayerUpdateStage) @@ -1303,17 +1318,14 @@ def test_build_plan_stages_indirect_non_breaking_view_migration( assert len(stages) == 8 assert isinstance(stages[0], CreateSnapshotRecordsStage) - assert isinstance(stages[1], PhysicalLayerUpdateStage) + assert isinstance(stages[1], PhysicalLayerSchemaCreationStage) assert isinstance(stages[2], BackfillStage) assert isinstance(stages[3], EnvironmentRecordUpdateStage) - assert isinstance(stages[4], MigrateSchemasStage) - assert isinstance(stages[5], UnpauseStage) + assert isinstance(stages[4], UnpauseStage) + assert isinstance(stages[5], BackfillStage) assert isinstance(stages[6], VirtualLayerUpdateStage) assert isinstance(stages[7], FinalizeEnvironmentStage) - migrate_schemas_stage = stages[4] - assert {s.snapshot_id for s in migrate_schemas_stage.snapshots} == {new_snapshot_c.snapshot_id} - def test_build_plan_stages_virtual_environment_mode_filtering( make_snapshot, mocker: MockerFixture diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 742642794f..b74aa3480e 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -21,6 +21,9 @@ interval_diff, compute_interval_params, SnapshotToIntervals, + EvaluateNode, + SchedulingUnit, + DummyNode, ) from sqlmesh.core.signal import signal from sqlmesh.core.snapshot import ( @@ -160,9 +163,10 @@ def test_incremental_by_unique_key_kind_dag( batches = get_batched_missing_intervals(scheduler, start, end, end) dag = scheduler._dag(batches) assert dag.graph == { - ( + EvaluateNode( unique_by_key_snapshot.name, - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), 0), + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), + batch_index=0, ): set(), } @@ -202,60 +206,66 @@ def test_incremental_time_self_reference_dag( assert dag.graph == { # Only run one day at a time and each day relies on the previous days - ( + EvaluateNode( incremental_self_snapshot.name, - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), 0), + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, ): set(), - ( + EvaluateNode( incremental_self_snapshot.name, - ((to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), 1), + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=1, ): { - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), 0), - ) + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ), }, - ( + EvaluateNode( incremental_self_snapshot.name, - ((to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), 2), + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=2, ): { - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), 1), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=1, ), }, - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), 3), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=3, ): { - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), 2), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=2, + ), + }, + DummyNode(snapshot_name=incremental_self_snapshot.name): { + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=1, + ), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=2, + ), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=3, ), }, - ( - incremental_self_snapshot.name, - ((to_timestamp(0), to_timestamp(0)), -1), - ): set( - [ - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), 0), - ), - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), 1), - ), - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), 2), - ), - ( - incremental_self_snapshot.name, - ((to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), 3), - ), - ] - ), } @@ -266,16 +276,26 @@ def test_incremental_time_self_reference_dag( 2, 2, { - ( - '"test_model"', - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), + batch_index=0, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-03"), to_timestamp("2023-01-05")), 1), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-05")), + batch_index=1, ): set(), - ('"test_model"', ((to_timestamp("2023-01-05"), to_timestamp("2023-01-07")), 2)): { - ('"test_model"', ((to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), 0)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-07")), + batch_index=2, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), + batch_index=0, + ), }, }, ), @@ -283,26 +303,53 @@ def test_incremental_time_self_reference_dag( 1, 3, { - ( - '"test_model"', - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), 1), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + batch_index=1, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), 2), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=2, ): set(), - ('"test_model"', ((to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), 3)): { - ('"test_model"', ((to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), 0)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=3, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ), }, - ('"test_model"', ((to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), 4)): { - ('"test_model"', ((to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), 1)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + batch_index=4, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + batch_index=1, + ), }, - ('"test_model"', ((to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), 5)): { - ('"test_model"', ((to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), 2)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=5, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=2, + ), }, }, ), @@ -310,29 +357,35 @@ def test_incremental_time_self_reference_dag( 1, 10, { - ( - '"test_model"', - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), 1), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + batch_index=1, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), 2), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=2, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), 3), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=3, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), 4), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + batch_index=4, ): set(), - ( - '"test_model"', - ((to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), 5), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=5, ): set(), }, ), @@ -340,9 +393,10 @@ def test_incremental_time_self_reference_dag( 10, 10, { - ( - '"test_model"', - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), + batch_index=0, ): set(), }, ), @@ -350,9 +404,10 @@ def test_incremental_time_self_reference_dag( 10, 1, { - ( - '"test_model"', - ((to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), + batch_index=0, ): set(), }, ), @@ -364,7 +419,7 @@ def test_incremental_batch_concurrency( get_batched_missing_intervals, batch_size: int, batch_concurrency: int, - expected_graph: t.Dict[str, t.Any], + expected_graph: t.Dict[SchedulingUnit, t.Set[SchedulingUnit]], ): start = to_datetime("2023-01-01") end = to_datetime("2023-01-07") @@ -392,7 +447,7 @@ def test_incremental_batch_concurrency( batches = get_batched_missing_intervals(scheduler, start, end, end) dag = scheduler._dag(batches) - graph = {k: v for k, v in dag.graph.items() if k[1][1] != -1} # exclude the terminal node.} + graph = {k: v for k, v in dag.graph.items() if isinstance(k, EvaluateNode)} assert graph == expected_graph diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index bcb704ba48..bce091595c 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -2052,6 +2052,7 @@ def test_deployability_index(make_snapshot): snapshot_f.parents = (snapshot_e.snapshot_id, snapshot_a.snapshot_id) snapshot_g = make_snapshot(SqlModel(name="g", query=parse_one("SELECT 1"))) + snapshot_g.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] snapshot_g.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) snapshot_g.parents = (snapshot_e.snapshot_id,) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index b05d567cd2..60931b1602 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -197,12 +197,6 @@ def x(evaluator, y=None) -> None: execute_calls = [call([parse_one('CREATE TABLE "hook_called"')])] adapter_mock.execute.assert_has_calls(execute_calls) - adapter_mock.create_schema.assert_has_calls( - [ - call(to_schema("sqlmesh__test_schema")), - ] - ) - common_kwargs = dict( target_columns_to_types={"a": exp.DataType.build("int")}, table_format=None, @@ -612,7 +606,6 @@ def test_evaluate_materialized_view_with_partitioned_by_cluster_by( execute_mock.assert_has_calls( [ - call("CREATE SCHEMA IF NOT EXISTS `sqlmesh__test_schema`"), call( f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`" ), @@ -812,7 +805,6 @@ def test_create_only_dev_table_exists(mocker: MockerFixture, adapter_mock, make_ evaluator = SnapshotEvaluator(adapter_mock) evaluator.create([snapshot], {}) - adapter_mock.create_schema.assert_called_once_with(to_schema("sqlmesh__test_schema")) adapter_mock.create_view.assert_not_called() adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__test_schema"), @@ -847,7 +839,6 @@ def test_create_new_forward_only_model(mocker: MockerFixture, adapter_mock, make evaluator = SnapshotEvaluator(adapter_mock) evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - adapter_mock.create_schema.assert_called_once_with(to_schema("sqlmesh__test_schema")) # Only non-deployable table should be created adapter_mock.create_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.dev_version}__dev", @@ -867,66 +858,57 @@ def test_create_new_forward_only_model(mocker: MockerFixture, adapter_mock, make adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__test_schema"), { - f"test_schema__test_model__{snapshot.version}", f"test_schema__test_model__{snapshot.dev_version}__dev", }, ) @pytest.mark.parametrize( - "deployability_index, snapshot_category, forward_only, deployability_flags", + "deployability_index, snapshot_category, forward_only", [ - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, False, [False]), - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.NON_BREAKING, False, [False]), - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, True, [True]), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, False), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.NON_BREAKING, False), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, True), ( DeployabilityIndex.all_deployable(), SnapshotChangeCategory.INDIRECT_BREAKING, False, - [False], ), ( DeployabilityIndex.all_deployable(), SnapshotChangeCategory.INDIRECT_NON_BREAKING, False, - [True], ), - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.METADATA, False, [True]), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.METADATA, False), ( DeployabilityIndex.none_deployable(), SnapshotChangeCategory.BREAKING, False, - [True, False], ), ( DeployabilityIndex.none_deployable(), SnapshotChangeCategory.NON_BREAKING, False, - [True, False], ), ( DeployabilityIndex.none_deployable(), SnapshotChangeCategory.BREAKING, True, - [True], ), ( DeployabilityIndex.none_deployable(), SnapshotChangeCategory.INDIRECT_BREAKING, False, - [True, False], ), ( DeployabilityIndex.none_deployable(), SnapshotChangeCategory.INDIRECT_NON_BREAKING, False, - [True], ), ( DeployabilityIndex.none_deployable(), SnapshotChangeCategory.METADATA, False, - [True], ), ], ) @@ -935,7 +917,6 @@ def test_create_tables_exist( mocker: MockerFixture, adapter_mock, deployability_index: DeployabilityIndex, - deployability_flags: t.List[bool], snapshot_category: SnapshotChangeCategory, forward_only: bool, ): @@ -967,8 +948,9 @@ def test_create_tables_exist( adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__db"), { - f"db__model__{snapshot.version}" if not flag else f"db__model__{snapshot.version}__dev" - for flag in set(deployability_flags + [False]) + f"db__model__{snapshot.version}" + if deployability_index.is_deployable(snapshot) + else f"db__model__{snapshot.version}__dev", }, ) adapter_mock.create_schema.assert_not_called() @@ -1005,24 +987,11 @@ def test_create_prod_table_exists_forward_only(mocker: MockerFixture, adapter_mo adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__test_schema"), { - f"test_schema__test_model__{snapshot.version}__dev", f"test_schema__test_model__{snapshot.version}", }, ) - adapter_mock.create_schema.assert_called_once_with(to_schema("sqlmesh__test_schema")) - adapter_mock.create_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", - target_columns_to_types={"a": exp.DataType.build("int")}, - table_format=None, - storage_format=None, - partitioned_by=[], - partition_interval_unit=None, - clustered_by=[], - table_properties={}, - table_description=None, - column_descriptions=None, - ) + adapter_mock.create_table.assert_not_called() def test_create_view_non_deployable_snapshot(mocker: MockerFixture, adapter_mock, make_snapshot): @@ -1304,15 +1273,7 @@ def test_migrate_missing_table(mocker: MockerFixture, make_snapshot, make_mocked evaluator.migrate([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - adapter.cursor.execute.assert_has_calls( - [ - call('CREATE TABLE "pre" ("a" INT)'), - call( - 'CREATE TABLE IF NOT EXISTS "sqlmesh__test_schema"."test_schema__test_model__1" AS SELECT "c" AS "c", "a" AS "a" FROM "tbl" AS "tbl" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\' AND FALSE LIMIT 0' - ), - call('DROP TABLE "pre"'), - ] - ) + adapter.cursor.execute.assert_not_called() @pytest.mark.parametrize( @@ -1350,13 +1311,7 @@ def test_migrate_view( evaluator.migrate([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - adapter.cursor.execute.assert_has_calls( - [ - call( - 'CREATE OR REPLACE VIEW "sqlmesh__test_schema"."test_schema__test_model__1" ("c", "a") AS SELECT "c" AS "c", "a" AS "a" FROM "tbl" AS "tbl"' - ) - ] - ) + adapter.cursor.execute.assert_not_called() def test_migrate_snapshot_data_object_type_mismatch( @@ -1369,7 +1324,7 @@ def test_migrate_snapshot_data_object_type_mismatch( adapter, "get_data_object", return_value=DataObject( - schema="sqlmesh__test_schema", name="test_schema__test_model__1", type="table" + schema="sqlmesh__test_schema", name="test_schema__test_model__1", type="view" ), ) mocker.patch.object(adapter, "table_exists", return_value=False) @@ -1378,22 +1333,20 @@ def test_migrate_snapshot_data_object_type_mismatch( model = SqlModel( name="test_schema.test_model", - kind=ViewKind(), + kind=FullKind(), storage_format="parquet", query=parse_one("SELECT c, a FROM tbl"), ) snapshot = make_snapshot(model, version="1") snapshot.change_category = SnapshotChangeCategory.BREAKING snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions evaluator.migrate([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter.cursor.execute.assert_has_calls( [ - call('DROP TABLE IF EXISTS "sqlmesh__test_schema"."test_schema__test_model__1"'), - call( - 'CREATE VIEW "sqlmesh__test_schema"."test_schema__test_model__1" AS SELECT "c" AS "c", "a" AS "a" FROM "tbl" AS "tbl"' - ), + call('DROP VIEW IF EXISTS "sqlmesh__test_schema"."test_schema__test_model__1"'), ] ) @@ -1404,6 +1357,7 @@ def test_evaluate_creation_duckdb( date_kwargs: t.Dict[str, str], ): evaluator = SnapshotEvaluator(create_engine_adapter(lambda: duck_conn, "duckdb")) + evaluator.create_physical_schemas([snapshot], DeployabilityIndex.all_deployable()) evaluator.create([snapshot], {}) version = snapshot.version @@ -1440,6 +1394,7 @@ def assert_tables_exist() -> None: def test_migrate_duckdb(snapshot: Snapshot, duck_conn, make_snapshot): evaluator = SnapshotEvaluator(create_engine_adapter(lambda: duck_conn, "duckdb")) + evaluator.create_physical_schemas([snapshot], DeployabilityIndex.all_deployable()) evaluator.create([snapshot], {}) updated_model_dict = snapshot.model.dict() @@ -1610,10 +1565,10 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot) ), ] - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter_mock.create_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, table_format=None, storage_format=None, @@ -1634,61 +1589,17 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot) adapter_mock.get_alter_expressions.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", ignore_destructive=False, ) adapter_mock.alter_table.assert_called_once_with([]) adapter_mock.drop_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source" + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp" ) -def test_create_clone_in_dev_missing_table(mocker: MockerFixture, adapter_mock, make_snapshot): - adapter_mock.SUPPORTS_CLONING = True - adapter_mock.get_alter_expressions.return_value = [] - evaluator = SnapshotEvaluator(adapter_mock) - - model = load_sql_based_model( - parse( # type: ignore - """ - MODEL ( - name test_schema.test_model, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column ds, - forward_only true, - ) - ); - - SELECT 1::INT as a, ds::DATE FROM a; - """ - ), - ) - - snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - snapshot.previous_versions = snapshot.all_versions - - evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - - adapter_mock.create_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.dev_version}__dev", - target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, - table_format=None, - storage_format=None, - partitioned_by=[exp.to_column("ds", quoted=True)], - partition_interval_unit=IntervalUnit.DAY, - clustered_by=[], - table_properties={}, - table_description=None, - column_descriptions=None, - ) - - adapter_mock.clone_table.assert_not_called() - adapter_mock.alter_table.assert_not_called() - - def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock.SUPPORTS_CLONING = True adapter_mock.get_alter_expressions.return_value = [] @@ -1724,7 +1635,7 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m ] with pytest.raises(SnapshotCreationFailedError): - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter_mock.clone_table.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", @@ -1735,7 +1646,7 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m adapter_mock.get_alter_expressions.assert_called_once_with( f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", ignore_destructive=False, ) @@ -1743,10 +1654,10 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m adapter_mock.drop_table.assert_has_calls( [ - call(f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev"), call( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source" + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp" ), + call(f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev"), ] ) @@ -1787,10 +1698,10 @@ def test_create_clone_in_dev_self_referencing( ), ] - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter_mock.create_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev__schema_migration_source", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, table_format=None, storage_format=None, @@ -1802,16 +1713,16 @@ def test_create_clone_in_dev_self_referencing( column_descriptions=None, ) - # Make sure the dry run references the correct ("...__schema_migration_source") table. + # Make sure the dry run references the correct ("..._schema_tmp") table. table_alias = ( "test_model" if not use_this_model - else f"test_schema__test_model__{snapshot.version}__dev__schema_migration_source" + else f"test_schema__test_model__{snapshot.version}__dev_schema_tmp" ) dry_run_query = adapter_mock.fetchall.call_args[0][0].sql() assert ( dry_run_query - == f'SELECT CAST(1 AS INT) AS "a", CAST("ds" AS DATE) AS "ds" FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev__schema_migration_source" AS "{table_alias}" /* test_schema.test_model */ WHERE FALSE LIMIT 0' + == f'SELECT CAST(1 AS INT) AS "a", CAST("ds" AS DATE) AS "ds" FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev_schema_tmp" AS "{table_alias}" /* test_schema.test_model */ WHERE FALSE LIMIT 0' ) @@ -1919,7 +1830,7 @@ def test_forward_only_snapshot_for_added_model(mocker: MockerFixture, adapter_mo snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_create_args = dict( target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, @@ -1963,7 +1874,7 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_kwargs = dict( target_columns_to_types={ @@ -1990,11 +1901,6 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot): column_descriptions=None, **common_kwargs, ), - call( - snapshot.table_name(), - column_descriptions={}, - **common_kwargs, - ), ] ) @@ -2021,7 +1927,7 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) query = parse_one( """SELECT *, CAST(NULL AS TIMESTAMPTZ) AS valid_from, CAST(NULL AS TIMESTAMPTZ) AS valid_to FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0""" @@ -2047,7 +1953,6 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot): column_descriptions=None, **common_kwargs, ), - call(snapshot.table_name(), query, None, column_descriptions={}, **common_kwargs), ] ) @@ -2142,7 +2047,7 @@ def test_create_scd_type_2_by_column(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_kwargs = dict( target_columns_to_types={ @@ -2167,7 +2072,6 @@ def test_create_scd_type_2_by_column(adapter_mock, make_snapshot): snapshot.table_name(is_deployable=False), **{**common_kwargs, "column_descriptions": None}, ), - call(snapshot.table_name(), **{**common_kwargs, "column_descriptions": {}}), ] ) @@ -2193,7 +2097,7 @@ def test_create_ctas_scd_type_2_by_column(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) query = parse_one( """SELECT *, CAST(NULL AS TIMESTAMP) AS valid_from, CAST(NULL AS TIMESTAMP) AS valid_to FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0""" @@ -2218,9 +2122,6 @@ def test_create_ctas_scd_type_2_by_column(adapter_mock, make_snapshot): None, **{**common_kwargs, "column_descriptions": None}, ), - call( - snapshot.table_name(), query, None, **{**common_kwargs, "column_descriptions": {}} - ), ] ) @@ -3152,14 +3053,7 @@ def create_log_table(evaluator, view_name): == f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a")' ) - post_calls = call_args[3][0][0] - assert len(post_calls) == 1 - assert ( - post_calls[0].sql(dialect="postgres") - == f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a")' - ) - - on_virtual_update_calls = call_args[4][0][0] + on_virtual_update_calls = call_args[2][0][0] assert ( on_virtual_update_calls[0].sql(dialect="postgres") == 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"' @@ -3237,7 +3131,7 @@ def model_with_statements(context, **kwargs): ) call_args = adapter_mock.execute.call_args_list - on_virtual_update_call = call_args[4][0][0][0] + on_virtual_update_call = call_args[2][0][0][0] assert ( on_virtual_update_call.sql(dialect="postgres") == 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")' @@ -3513,9 +3407,47 @@ def test_create_managed(adapter_mock, make_snapshot, mocker: MockerFixture): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.all_deployable()) + + adapter_mock.create_managed_table.assert_called_with( + table_name=snapshot.table_name(), + query=mocker.ANY, + target_columns_to_types=model.columns_to_types, + partitioned_by=model.partitioned_by, + clustered_by=model.clustered_by, + table_properties=model.physical_properties, + table_description=model.description, + column_descriptions=model.column_descriptions, + table_format=None, + ) + + +def test_create_managed_dev(adapter_mock, make_snapshot, mocker: MockerFixture): + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind MANAGED, + physical_properties ( + warehouse = 'small', + target_lag = '10 minutes' + ), + clustered_by a + ); + + select a, b from foo; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - # first call to evaluation_strategy.create(), is_table_deployable=False triggers a normal table adapter_mock.ctas.assert_called_once_with( f"{snapshot.table_name()}__dev", mocker.ANY, @@ -3530,19 +3462,6 @@ def test_create_managed(adapter_mock, make_snapshot, mocker: MockerFixture): column_descriptions=None, ) - # second call to evaluation_strategy.create(), is_table_deployable=True and is_snapshot_deployable=True triggers a managed table - adapter_mock.create_managed_table.assert_called_with( - table_name=snapshot.table_name(), - query=mocker.ANY, - target_columns_to_types=model.columns_to_types, - partitioned_by=model.partitioned_by, - clustered_by=model.clustered_by, - table_properties=model.physical_properties, - table_description=model.description, - column_descriptions=model.column_descriptions, - table_format=None, - ) - def test_evaluate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): evaluator = SnapshotEvaluator(adapter_mock) @@ -3695,7 +3614,11 @@ def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_ ), ] - evaluator.create(target_snapshots=[snapshot], snapshots={}) + evaluator.create( + target_snapshots=[snapshot], + snapshots={}, + deployability_index=DeployabilityIndex.none_deployable(), + ) # We dont clone managed tables to create dev previews, we use normal tables adapter_mock.clone_table.assert_not_called() @@ -3707,114 +3630,6 @@ def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_ assert adapter_mock.ctas.call_args_list[0].args[0] == snapshot.table_name(is_deployable=False) -@pytest.mark.parametrize( - "deployability_index, snapshot_category, forward_only, deployability_flags", - [ - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, False, [True]), - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.NON_BREAKING, False, [True]), - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, True, [False]), - ( - DeployabilityIndex.all_deployable(), - SnapshotChangeCategory.INDIRECT_BREAKING, - False, - [True], - ), - ( - DeployabilityIndex.all_deployable(), - SnapshotChangeCategory.INDIRECT_NON_BREAKING, - False, - [False], - ), - (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.METADATA, False, [False]), - ( - DeployabilityIndex.none_deployable(), - SnapshotChangeCategory.BREAKING, - False, - [False, True], - ), - ( - DeployabilityIndex.none_deployable(), - SnapshotChangeCategory.NON_BREAKING, - False, - [False, True], - ), - ( - DeployabilityIndex.none_deployable(), - SnapshotChangeCategory.BREAKING, - True, - [False], - ), - ( - DeployabilityIndex.none_deployable(), - SnapshotChangeCategory.INDIRECT_BREAKING, - False, - [False, True], - ), - ( - DeployabilityIndex.none_deployable(), - SnapshotChangeCategory.INDIRECT_NON_BREAKING, - False, - [False], - ), - ( - DeployabilityIndex.none_deployable(), - SnapshotChangeCategory.METADATA, - False, - [False], - ), - ], -) -def test_create_snapshot( - snapshot: Snapshot, - mocker: MockerFixture, - adapter_mock, - deployability_index: DeployabilityIndex, - deployability_flags: t.List[bool], - snapshot_category: SnapshotChangeCategory, - forward_only: bool, -): - adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") - adapter_mock.dialect = "duckdb" - - evaluator = SnapshotEvaluator(adapter_mock) - snapshot.categorize_as(category=snapshot_category, forward_only=forward_only) - evaluator._create_snapshot( - snapshot=snapshot, - snapshots={}, - deployability_flags=deployability_flags, - deployability_index=deployability_index, - on_complete=None, - allow_destructive_snapshots=set(), - ) - - common_kwargs: t.Dict[str, t.Any] = dict( - target_columns_to_types={"a": exp.DataType.build("int")}, - table_format=None, - storage_format=None, - partitioned_by=[], - partition_interval_unit=None, - clustered_by=[], - table_properties={}, - table_description=None, - ) - - tables_created = [ - call( - snapshot.table_name(is_deployable=is_deployable), - column_descriptions=(None if not is_deployable else {}), - **common_kwargs, - ) - for is_deployable in deployability_flags - ] - - adapter_mock.create_table.assert_has_calls(tables_created) - - # Even if one or two (prod and dev) tables are created, the dry run should be conducted once - adapter_mock.fetchall.assert_called_once_with( - parse_one('SELECT CAST("a" AS INT) AS "a" FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0') - ) - - def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" @@ -3838,7 +3653,6 @@ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_moc ) adapter_mock.drop_data_object_on_type_mismatch.return_value = False - evaluator.create([new_snapshot], {}) evaluator.migrate([new_snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_kwargs: t.Dict[str, t.Any] = dict( @@ -3860,7 +3674,7 @@ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_moc **common_kwargs, ), call( - new_snapshot.table_name(is_deployable=False), + f"{new_snapshot.table_name()}_schema_tmp", target_columns_to_types={ "a": exp.DataType.build("int"), "b": exp.DataType.build("int"), @@ -3886,7 +3700,7 @@ def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_moc adapter_mock.get_alter_expressions.assert_called_once_with( snapshot.table_name(), - new_snapshot.table_name(is_deployable=False), + f"{new_snapshot.table_name()}_schema_tmp", ignore_destructive=False, ) @@ -3925,7 +3739,8 @@ def test_migrate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): adapter_mock.create_table.assert_not_called() adapter_mock.create_managed_table.assert_not_called() - adapter_mock.ctas.assert_not_called() + adapter_mock.ctas.assert_called_once() + adapter_mock.reset_mock() # schema changes - exception thrown adapter_mock.get_alter_expressions.return_value = [exp.Alter()] @@ -3945,7 +3760,7 @@ def test_migrate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): ) adapter_mock.create_table.assert_not_called() - adapter_mock.ctas.assert_not_called() + adapter_mock.ctas.assert_called_once() adapter_mock.create_managed_table.assert_not_called() @@ -4155,7 +3970,7 @@ def columns(table_name): # The second mock adapter has to be called only for the gateway-specific model adapter_mock.get_alter_expressions.assert_called_once_with( snapshot_2.table_name(True), - snapshot_2.table_name(False), + f"{snapshot_2.table_name(True)}_schema_tmp", ignore_destructive=False, ) diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index f34a1c6c74..44b6cd7911 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -939,11 +939,11 @@ def test_connection_args(tmp_path): dbt_project_dir = "tests/fixtures/dbt/sushi_test" config = sqlmesh_config(dbt_project_dir) - assert config.gateways["in_memory"].connection.register_comments - - config = sqlmesh_config(dbt_project_dir, register_comments=False) assert not config.gateways["in_memory"].connection.register_comments + config = sqlmesh_config(dbt_project_dir, register_comments=True) + assert config.gateways["in_memory"].connection.register_comments + def test_custom_dbt_loader(): from sqlmesh.core.loader import SqlMeshLoader diff --git a/tests/integrations/github/cicd/test_integration.py b/tests/integrations/github/cicd/test_integration.py index 3fb965f310..d69311fb3d 100644 --- a/tests/integrations/github/cicd/test_integration.py +++ b/tests/integrations/github/cicd/test_integration.py @@ -89,10 +89,17 @@ def test_linter( mock_pull_request.merged = False mock_pull_request.merge = mocker.MagicMock() + before_all = [ + "CREATE SCHEMA IF NOT EXISTS raw", + "DROP VIEW IF EXISTS raw.demographics", + "CREATE VIEW raw.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)", + ] + # Case 1: Test for linter errors config = Config( model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=LinterConfig(enabled=True, rules="ALL"), + before_all=before_all, ) controller = make_controller( @@ -142,6 +149,7 @@ def test_linter( config = Config( model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=LinterConfig(enabled=True, warn_rules="ALL"), + before_all=before_all, ) controller = make_controller( diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 586d8abb6d..ae0742f1db 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -82,7 +82,9 @@ def use_terminal_console(func): def test_wrapper(*args, **kwargs): orig_console = get_console() try: - set_console(TerminalConsole()) + new_console = TerminalConsole() + new_console.console.no_color = True + set_console(new_console) func(*args, **kwargs) finally: set_console(orig_console)