diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 82223dd807..91c8c6ff14 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -574,23 +574,31 @@ def _get_audit_only_snapshots( ) -> t.Dict[SnapshotId, Snapshot]: metadata_snapshots = [] for snapshot in new_snapshots.values(): - if not snapshot.is_metadata or not snapshot.is_model or not snapshot.evaluatable: + if ( + not snapshot.is_metadata + or not snapshot.is_model + or not snapshot.evaluatable + or not snapshot.previous_version + ): continue metadata_snapshots.append(snapshot) # Bulk load all the previous snapshots - previous_snapshots = self.state_reader.get_snapshots( - [ - s.previous_version.snapshot_id(s.name) - for s in metadata_snapshots - if s.previous_version - ] - ).values() + previous_snapshot_ids = [ + s.previous_version.snapshot_id(s.name) for s in metadata_snapshots if s.previous_version + ] + previous_snapshots = { + s.name: s for s in self.state_reader.get_snapshots(previous_snapshot_ids).values() + } # Check if any of the snapshots have modifications to the audits field by comparing the hashes audit_snapshots = {} - for snapshot, previous_snapshot in zip(metadata_snapshots, previous_snapshots): + for snapshot in metadata_snapshots: + if snapshot.name not in previous_snapshots: + continue + + previous_snapshot = previous_snapshots[snapshot.name] new_audits_hash = snapshot.model.audit_metadata_hash() previous_audit_hash = previous_snapshot.model.audit_metadata_hash() diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 210aff230d..dc6499c1a3 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -626,6 +626,9 @@ def _dag( dag = DAG[SchedulingUnit]() for snapshot_id in snapshot_dag: + if snapshot_id.name not in self.snapshots_by_name: + continue + snapshot = self.snapshots_by_name[snapshot_id.name] intervals = intervals_per_snapshot.get(snapshot.name, []) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 2781909c88..0e779481fd 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -6294,6 +6294,31 @@ def test_restatement_shouldnt_backfill_beyond_prod_intervals(init_and_plan_conte ].intervals[-1][1] == to_timestamp("2023-01-08 00:00:00 UTC") +@time_machine.travel("2023-01-08 15:00:00 UTC") +@use_terminal_console +def test_audit_only_metadata_change(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Add a new audit + model = context.get_model("sushi.waiter_revenue_by_day") + audits = model.audits.copy() + audits.append(("number_of_rows", {"threshold": exp.Literal.number(1)})) + model = model.copy(update={"audits": audits}) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + assert len(plan.new_snapshots) == 2 + assert all(s.change_category.is_metadata for s in plan.new_snapshots) + assert not plan.missing_intervals + + with capture_output() as output: + context.apply(plan) + + assert "Auditing models" in output.stdout + assert model.name in output.stdout + + def initial_add(context: Context, environment: str): assert not context.state_reader.get_environment(environment)