diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 43283ead90..3b9fce7f4e 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -428,6 +428,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -575,6 +576,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: pass @@ -1056,6 +1058,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation progress.""" if ( @@ -3639,6 +3642,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3808,11 +3812,15 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: - message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + + if auto_restatement_triggers: + message += f" | auto_restatement_triggers=[{', '.join(trigger.name for trigger in auto_restatement_triggers)}]" if audit_only: - message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + message = f"Audited {snapshot.name} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" self._write(message) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index e787e57a23..8096ffece1 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -415,6 +415,7 @@ def run_merged_intervals( selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, audit_only: bool = False, + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. @@ -531,6 +532,9 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, + auto_restatement_triggers=auto_restatement_triggers.get( + snapshot.snapshot_id + ), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( @@ -736,8 +740,11 @@ def _run_or_audit( for s_id, interval in (remove_intervals or {}).items(): self.snapshots[s_id].remove_interval(interval) + all_auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} if auto_restatement_enabled: - auto_restated_intervals = apply_auto_restatements(self.snapshots, execution_time) + auto_restated_intervals, all_auto_restatement_triggers = apply_auto_restatements( + self.snapshots, execution_time + ) self.state_sync.add_snapshots_intervals(auto_restated_intervals) self.state_sync.update_auto_restatements( {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()} @@ -758,6 +765,14 @@ def _run_or_audit( if not merged_intervals: return CompletionStatus.NOTHING_TO_DO + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + if all_auto_restatement_triggers: + merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals} + auto_restatement_triggers = { + s_id: all_auto_restatement_triggers.get(s_id, []) + for s_id in merged_intervals_snapshots + } + errors, _ = self.run_merged_intervals( merged_intervals=merged_intervals, deployability_index=deployability_index, @@ -768,6 +783,7 @@ def _run_or_audit( end=end, run_environment_statements=run_environment_statements, audit_only=audit_only, + auto_restatement_triggers=auto_restatement_triggers, ) return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index ec5a883f7f..45740d9810 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -21,7 +21,7 @@ from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind from sqlmesh.core.model.definition import _Model from sqlmesh.core.node import IntervalUnit, NodeType -from sqlmesh.utils import sanitize_name +from sqlmesh.utils import sanitize_name, unique from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( TimeLike, @@ -2180,7 +2180,7 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]: def apply_auto_restatements( snapshots: t.Dict[SnapshotId, Snapshot], execution_time: TimeLike -) -> t.List[SnapshotIntervals]: +) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, t.List[SnapshotId]]]: """Applies auto restatements to the snapshots. This operation results in the removal of intervals for snapshots that are ready to be restated based @@ -2195,6 +2195,7 @@ def apply_auto_restatements( A list of SnapshotIntervals with **new** intervals that need to be restated. """ dag = snapshots_to_dag(snapshots.values()) + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} auto_restated_intervals_per_snapshot: t.Dict[SnapshotId, Interval] = {} for s_id in dag: if s_id not in snapshots: @@ -2209,6 +2210,7 @@ def apply_auto_restatements( for parent_s_id in snapshot.parents if parent_s_id in auto_restated_intervals_per_snapshot ] + upstream_triggers = [] if next_auto_restated_interval: logger.info( "Calculated the next auto restated interval (%s, %s) for snapshot %s", @@ -2218,6 +2220,18 @@ def apply_auto_restatements( ) auto_restated_intervals.append(next_auto_restated_interval) + # auto-restated snapshot is its own trigger + upstream_triggers = [s_id] + else: + # inherit each parent's auto-restatement triggers (if any) + for parent_s_id in snapshot.parents: + if parent_s_id in auto_restatement_triggers: + upstream_triggers.extend(auto_restatement_triggers[parent_s_id]) + + # remove duplicate triggers, retaining order and keeping first seen of duplicates + if upstream_triggers: + auto_restatement_triggers[s_id] = unique(upstream_triggers) + if auto_restated_intervals: auto_restated_interval_start = sys.maxsize auto_restated_interval_end = -sys.maxsize @@ -2247,20 +2261,22 @@ def apply_auto_restatements( snapshot.apply_pending_restatement_intervals() snapshot.update_next_auto_restatement_ts(execution_time) - - return [ - SnapshotIntervals( - name=snapshots[s_id].name, - identifier=None, - version=snapshots[s_id].version, - dev_version=None, - intervals=[], - dev_intervals=[], - pending_restatement_intervals=[interval], - ) - for s_id, interval in auto_restated_intervals_per_snapshot.items() - if s_id in snapshots - ] + return ( + [ + SnapshotIntervals( + name=snapshots[s_id].name, + identifier=None, + version=snapshots[s_id].version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[interval], + ) + for s_id, interval in auto_restated_intervals_per_snapshot.items() + if s_id in snapshots + ], + auto_restatement_triggers, + ) def parent_snapshots_by_name( diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index fc129424f4..827d84e8b9 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -1862,6 +1862,82 @@ def test_select_unchanged_model_for_backfill(init_and_plan_context: t.Callable): assert {o.name for o in schema_objects} == {"waiter_revenue_by_day", "top_waiters"} +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixture): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # auto-restatement triggers + orders = context.get_model("sushi.orders") + orders_kind = { + **orders.kind.dict(), + "auto_restatement_cron": "@hourly", + } + orders_kwargs = { + **orders.dict(), + "kind": orders_kind, + } + context.upsert_model(PythonModel.parse_obj(orders_kwargs)) + + order_items = context.get_model("sushi.order_items") + order_items_kind = { + **order_items.kind.dict(), + "auto_restatement_cron": "@hourly", + } + order_items_kwargs = { + **order_items.dict(), + "kind": order_items_kind, + } + context.upsert_model(PythonModel.parse_obj(order_items_kwargs)) + + waiter_revenue_by_day = context.get_model("sushi.waiter_revenue_by_day") + waiter_revenue_by_day_kind = { + **waiter_revenue_by_day.kind.dict(), + "auto_restatement_cron": "@hourly", + } + waiter_revenue_by_day_kwargs = { + **waiter_revenue_by_day.dict(), + "kind": waiter_revenue_by_day_kind, + } + context.upsert_model(SqlModel.parse_obj(waiter_revenue_by_day_kwargs)) + + context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) + + scheduler = context.scheduler() + + import sqlmesh + + spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals") + + with time_machine.travel("2023-01-09 00:00:01 UTC"): + scheduler.run( + environment=c.PROD, + start="2023-01-01", + auto_restatement_enabled=True, + ) + + assert spy.called + + actual_triggers = spy.call_args.kwargs["auto_restatement_triggers"] + actual_triggers = {k: v for k, v in actual_triggers.items() if v} + assert len(actual_triggers) == 12 + + for id, trigger in actual_triggers.items(): + model_name = id.name.replace('"memory"."sushi".', "").replace('"', "") + auto_restatement_triggers = [ + t.name.replace('"memory"."sushi".', "").replace('"', "") for t in trigger + ] + + if model_name in ("orders", "order_items", "waiter_revenue_by_day"): + assert auto_restatement_triggers == [model_name] + elif model_name in ("customer_revenue_lifetime", "customer_revenue_by_day"): + assert sorted(auto_restatement_triggers) == sorted(["orders", "order_items"]) + elif model_name == "top_waiters": + assert auto_restatement_triggers == ["waiter_revenue_by_day"] + else: + assert auto_restatement_triggers == ["orders"] + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_max_interval_end_per_model_not_applied_when_end_is_provided( init_and_plan_context: t.Callable, @@ -6962,7 +7038,18 @@ def plan_with_output(ctx: Context, environment: str): assert "New environment `dev` will be created from `prod`" in output.stdout assert "Differences from the `prod` environment" in output.stdout - assert "Directly Modified: test__dev.a" in output.stdout + stdout_rstrip = "\n".join([line.rstrip() for line in output.stdout.split("\n")]) + assert ( + """MODEL ( + name test.a, ++ owner test, + kind FULL + ) + SELECT +- 5 AS col ++ 10 AS col""" + in stdout_rstrip + ) # Case 6: Ensure that target environment and create_from environment are not the same output = plan_with_output(ctx, "prod") @@ -7705,7 +7792,7 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -7749,7 +7836,7 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 2 as id, 3 as new_column, @@ -7803,9 +7890,9 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): start '2023-01-01', cron '@daily' ); - + SELECT - *, + *, 2 as id, CAST(4 AS STRING) as new_column, @start_ds as ds @@ -7844,8 +7931,8 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): start '2023-01-01', cron '@daily' ); - - SELECT + + SELECT *, 2 as id, CAST(5 AS STRING) as new_column, @@ -7905,7 +7992,7 @@ def test_incremental_by_unique_key_model_ignore_destructive_change(tmp_path: Pat cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -7949,7 +8036,7 @@ def test_incremental_by_unique_key_model_ignore_destructive_change(tmp_path: Pat cron '@daily' ); - SELECT + SELECT *, 2 as id, 3 as new_column, @@ -8016,7 +8103,7 @@ def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8059,7 +8146,7 @@ def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): ); SELECT - *, + *, 2 as id, 3 as new_column, @start_ds as ds @@ -8240,7 +8327,7 @@ def test_scd_type_2_by_column_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8285,7 +8372,7 @@ def test_scd_type_2_by_column_ignore_destructive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 3 as new_column, @start_ds as ds @@ -8352,7 +8439,7 @@ def test_incremental_partition_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8396,7 +8483,7 @@ def test_incremental_partition_ignore_destructive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 3 as new_column, @start_ds as ds @@ -8467,7 +8554,7 @@ def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: cron '@daily' ); - SELECT + SELECT id, name, ds @@ -8479,7 +8566,7 @@ def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: (models_dir / "test_model.sql").write_text(initial_model) initial_test = f""" - + test_test_model: model: test_model inputs: @@ -8534,8 +8621,8 @@ def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: start '2023-01-01', cron '@daily' ); - - SELECT + + SELECT id, new_column, ds diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index bce091595c..db61b9cabf 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3102,7 +3102,7 @@ def test_apply_auto_restatements(make_snapshot): (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), ] - restated_intervals = apply_auto_restatements( + restated_intervals, _ = apply_auto_restatements( { snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, @@ -3239,7 +3239,7 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): snapshot_b.add_interval("2020-01-01", "2020-01-05") assert snapshot_a.snapshot_id in snapshot_b.parents - restated_intervals = apply_auto_restatements( + restated_intervals, _ = apply_auto_restatements( { snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, @@ -3279,6 +3279,116 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): ] +def test_auto_restatement_triggers(make_snapshot): + # Auto restatements: + # a, c, d + # dag: + # a -> b + # a -> c + # [b, c, d] -> e + model_a = SqlModel( + name="test_model_a", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT 1 as ds"), + ) + snapshot_a = make_snapshot(model_a, version="1") + snapshot_a.add_interval("2020-01-01", "2020-01-05") + snapshot_a.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_b = SqlModel( + name="test_model_b", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_b = make_snapshot(model_b, nodes={model_a.fqn: model_a}, version="1") + snapshot_b.add_interval("2020-01-01", "2020-01-05") + + model_c = SqlModel( + name="test_model_c", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_c = make_snapshot(model_c, nodes={model_a.fqn: model_a}, version="1") + snapshot_c.add_interval("2020-01-01", "2020-01-05") + snapshot_c.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_d = SqlModel( + name="test_model_d", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT 1 as ds"), + ) + snapshot_d = make_snapshot(model_d, version="1") + snapshot_d.add_interval("2020-01-01", "2020-01-05") + snapshot_d.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_e = SqlModel( + name="test_model_e", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + start="2020-01-01", + cron="@daily", + query=parse_one( + "SELECT ds from test_model_b UNION ALL SELECT ds from test_model_c UNION ALL SELECT ds from test_model_d" + ), + ) + snapshot_e = make_snapshot( + model_e, + nodes={ + model_a.fqn: model_a, + model_b.fqn: model_b, + model_c.fqn: model_c, + model_d.fqn: model_d, + }, + version="1", + ) + snapshot_e.add_interval("2020-01-01", "2020-01-05") + + _, auto_restatement_triggers = apply_auto_restatements( + { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_c.snapshot_id: snapshot_c, + snapshot_d.snapshot_id: snapshot_d, + snapshot_e.snapshot_id: snapshot_e, + }, + "2020-01-06 10:01:00", + ) + + assert auto_restatement_triggers[snapshot_a.snapshot_id] == [snapshot_a.snapshot_id] + assert auto_restatement_triggers[snapshot_c.snapshot_id] == [snapshot_c.snapshot_id] + assert auto_restatement_triggers[snapshot_d.snapshot_id] == [snapshot_d.snapshot_id] + assert auto_restatement_triggers[snapshot_b.snapshot_id] == [snapshot_a.snapshot_id] + # a via b, c and d directly + assert sorted(auto_restatement_triggers[snapshot_e.snapshot_id]) == [ + snapshot_a.snapshot_id, + snapshot_c.snapshot_id, + snapshot_d.snapshot_id, + ] + + def test_render_signal(make_snapshot, mocker): @signal() def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int = 0): diff --git a/web/server/console.py b/web/server/console.py index 2cda0af697..902a85418c 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -9,7 +9,7 @@ from sqlmesh.core.console import TerminalConsole from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.plan.definition import EvaluatablePlan -from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo +from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId from sqlmesh.core.test import ModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp @@ -142,6 +142,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: if audit_only: return