Skip to content

Commit 37f6fa4

Browse files
committed
List all upstream auto-restated models
1 parent 06f2788 commit 37f6fa4

File tree

5 files changed

+128 -29 lines changed

sqlmesh/core/console.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def update_snapshot_evaluation_progress(
428428
num_audits_passed: int,
429429
num_audits_failed: int,
430430
audit_only: bool = False,
431-
auto_restatement_trigger: t.Optional[SnapshotId] = None,
431+
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
432432
) -> None:
433433
"""Updates the snapshot evaluation progress."""
434434

@@ -576,7 +576,7 @@ def update_snapshot_evaluation_progress(
576576
num_audits_passed: int,
577577
num_audits_failed: int,
578578
audit_only: bool = False,
579-
auto_restatement_trigger: t.Optional[SnapshotId] = None,
579+
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
580580
) -> None:
581581
pass
582582

@@ -1058,7 +1058,7 @@ def update_snapshot_evaluation_progress(
10581058
num_audits_passed: int,
10591059
num_audits_failed: int,
10601060
audit_only: bool = False,
1061-
auto_restatement_trigger: t.Optional[SnapshotId] = None,
1061+
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
10621062
) -> None:
10631063
"""Update the snapshot evaluation progress."""
10641064
if (
@@ -3638,7 +3638,7 @@ def update_snapshot_evaluation_progress(
36383638
num_audits_passed: int,
36393639
num_audits_failed: int,
36403640
audit_only: bool = False,
3641-
auto_restatement_trigger: t.Optional[SnapshotId] = None,
3641+
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
36423642
) -> None:
36433643
view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id]
36443644

@@ -3808,12 +3808,12 @@ def update_snapshot_evaluation_progress(
38083808
num_audits_passed: int,
38093809
num_audits_failed: int,
38103810
audit_only: bool = False,
3811-
auto_restatement_trigger: t.Optional[SnapshotId] = None,
3811+
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
38123812
) -> None:
38133813
message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
38143814

3815-
if auto_restatement_trigger:
3816-
message += f" | evaluation_triggered_by={auto_restatement_trigger.name}"
3815+
if auto_restatement_triggers:
3816+
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in auto_restatement_triggers)}"
38173817

38183818
if audit_only:
38193819
message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def run_merged_intervals(
374374
run_environment_statements: bool = False,
375375
audit_only: bool = False,
376376
restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
377-
auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {},
377+
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
378378
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
379379
"""Runs precomputed batches of missing intervals.
380380
@@ -477,7 +477,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
477477
evaluation_duration_ms,
478478
num_audits - num_audits_failed,
479479
num_audits_failed,
480-
auto_restatement_trigger=auto_restatement_triggers.get(snapshot.snapshot_id),
480+
auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id),
481481
)
482482

483483
try:
@@ -641,7 +641,7 @@ def _run_or_audit(
641641
for s_id, interval in (remove_intervals or {}).items():
642642
self.snapshots[s_id].remove_interval(interval)
643643

644-
auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {}
644+
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
645645
if auto_restatement_enabled:
646646
auto_restated_intervals, auto_restatement_triggers = apply_auto_restatements(
647647
self.snapshots, execution_time

sqlmesh/core/snapshot/definition.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,7 +2082,7 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]:
20822082

20832083
def apply_auto_restatements(
20842084
snapshots: t.Dict[SnapshotId, Snapshot], execution_time: TimeLike
2085-
) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, SnapshotId]]:
2085+
) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, t.List[SnapshotId]]]:
20862086
"""Applies auto restatements to the snapshots.
20872087
20882088
This operation results in the removal of intervals for snapshots that are ready to be restated based
@@ -2097,8 +2097,7 @@ def apply_auto_restatements(
20972097
A list of SnapshotIntervals with **new** intervals that need to be restated.
20982098
"""
20992099
dag = snapshots_to_dag(snapshots.values())
2100-
snapshots_with_auto_restatements: t.List[SnapshotId] = []
2101-
auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {}
2100+
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
21022101
auto_restated_intervals_per_snapshot: t.Dict[SnapshotId, Interval] = {}
21032102
for s_id in dag:
21042103
if s_id not in snapshots:
@@ -2113,6 +2112,7 @@ def apply_auto_restatements(
21132112
for parent_s_id in snapshot.parents
21142113
if parent_s_id in auto_restated_intervals_per_snapshot
21152114
]
2115+
upstream_triggers = []
21162116
if next_auto_restated_interval:
21172117
logger.info(
21182118
"Calculated the next auto restated interval (%s, %s) for snapshot %s",
@@ -2123,21 +2123,15 @@ def apply_auto_restatements(
21232123
auto_restated_intervals.append(next_auto_restated_interval)
21242124

21252125
# auto-restated snapshot is its own trigger
2126-
snapshots_with_auto_restatements.append(s_id)
2127-
auto_restatement_triggers[s_id] = s_id
2128-
else:
2129-
for parent_s_id in snapshot.parents:
2130-
# first auto-restated parent is the trigger
2131-
if parent_s_id in snapshots_with_auto_restatements:
2132-
auto_restatement_triggers[s_id] = parent_s_id
2133-
break
2134-
# if no trigger yet and parent has trigger, inherit their trigger
2135-
# - will be overwritten if a different parent is auto-restated
2136-
if (
2137-
parent_s_id in auto_restatement_triggers
2138-
and s_id not in auto_restatement_triggers
2139-
):
2140-
auto_restatement_triggers[s_id] = auto_restatement_triggers[parent_s_id]
2126+
upstream_triggers = [s_id]
2127+
2128+
for parent_s_id in snapshot.parents:
2129+
if parent_s_id in auto_restatement_triggers:
2130+
upstream_triggers.extend(auto_restatement_triggers[parent_s_id])
2131+
2132+
# remove duplicate triggers
2133+
if upstream_triggers:
2134+
auto_restatement_triggers[s_id] = list(dict.fromkeys(upstream_triggers))
21412135

21422136
if auto_restated_intervals:
21432137
auto_restated_interval_start = sys.maxsize

tests/core/test_snapshot.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2989,6 +2989,111 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot):
29892989
]
29902990

29912991

2992+
def test_auto_restatement_triggers(make_snapshot):
2993+
model_a = SqlModel(
2994+
name="test_model_a",
2995+
kind=IncrementalByTimeRangeKind(
2996+
time_column=TimeColumn(column="ds"),
2997+
auto_restatement_cron="0 10 * * *",
2998+
auto_restatement_intervals=24,
2999+
),
3000+
start="2020-01-01",
3001+
cron="@daily",
3002+
query=parse_one("SELECT 1 as ds"),
3003+
)
3004+
snapshot_a = make_snapshot(model_a, version="1")
3005+
snapshot_a.add_interval("2020-01-01", "2020-01-05")
3006+
snapshot_a.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00")
3007+
3008+
model_b = SqlModel(
3009+
name="test_model_b",
3010+
kind=IncrementalByTimeRangeKind(
3011+
time_column=TimeColumn(column="ds"),
3012+
),
3013+
start="2020-01-01",
3014+
cron="@daily",
3015+
query=parse_one("SELECT ds FROM test_model_a"),
3016+
)
3017+
snapshot_b = make_snapshot(model_b, nodes={model_a.fqn: model_a}, version="1")
3018+
snapshot_b.add_interval("2020-01-01", "2020-01-05")
3019+
3020+
model_c = SqlModel(
3021+
name="test_model_c",
3022+
kind=IncrementalByTimeRangeKind(
3023+
time_column=TimeColumn(column="ds"),
3024+
auto_restatement_cron="0 10 * * *",
3025+
auto_restatement_intervals=24,
3026+
),
3027+
start="2020-01-01",
3028+
cron="@daily",
3029+
query=parse_one("SELECT ds FROM test_model_a"),
3030+
)
3031+
snapshot_c = make_snapshot(model_c, nodes={model_a.fqn: model_a}, version="1")
3032+
snapshot_c.add_interval("2020-01-01", "2020-01-05")
3033+
snapshot_c.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00")
3034+
3035+
model_d = SqlModel(
3036+
name="test_model_d",
3037+
kind=IncrementalByTimeRangeKind(
3038+
time_column=TimeColumn(column="ds"),
3039+
auto_restatement_cron="0 10 * * *",
3040+
auto_restatement_intervals=24,
3041+
),
3042+
start="2020-01-01",
3043+
cron="@daily",
3044+
query=parse_one("SELECT 1 as ds"),
3045+
)
3046+
snapshot_d = make_snapshot(model_d, version="1")
3047+
snapshot_d.add_interval("2020-01-01", "2020-01-05")
3048+
snapshot_d.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00")
3049+
3050+
model_e = SqlModel(
3051+
name="test_model_e",
3052+
kind=IncrementalByTimeRangeKind(
3053+
time_column=TimeColumn(column="ds"),
3054+
),
3055+
start="2020-01-01",
3056+
cron="@daily",
3057+
query=parse_one(
3058+
"SELECT ds from test_model_b UNION ALL SELECT ds from test_model_c UNION ALL SELECT ds from test_model_d"
3059+
),
3060+
)
3061+
snapshot_e = make_snapshot(
3062+
model_e,
3063+
nodes={
3064+
model_a.fqn: model_a,
3065+
model_b.fqn: model_b,
3066+
model_c.fqn: model_c,
3067+
model_d.fqn: model_d,
3068+
},
3069+
version="1",
3070+
)
3071+
snapshot_e.add_interval("2020-01-01", "2020-01-05")
3072+
3073+
_, auto_restatement_triggers = apply_auto_restatements(
3074+
{
3075+
snapshot_a.snapshot_id: snapshot_a,
3076+
snapshot_b.snapshot_id: snapshot_b,
3077+
snapshot_c.snapshot_id: snapshot_c,
3078+
snapshot_d.snapshot_id: snapshot_d,
3079+
snapshot_e.snapshot_id: snapshot_e,
3080+
},
3081+
"2020-01-06 10:01:00",
3082+
)
3083+
3084+
assert auto_restatement_triggers == {
3085+
snapshot_a.snapshot_id: [snapshot_a.snapshot_id],
3086+
snapshot_d.snapshot_id: [snapshot_d.snapshot_id],
3087+
snapshot_b.snapshot_id: [snapshot_a.snapshot_id],
3088+
snapshot_c.snapshot_id: [snapshot_c.snapshot_id, snapshot_a.snapshot_id],
3089+
snapshot_e.snapshot_id: [
3090+
snapshot_d.snapshot_id,
3091+
snapshot_c.snapshot_id,
3092+
snapshot_a.snapshot_id,
3093+
],
3094+
}
3095+
3096+
29923097
def test_render_signal(make_snapshot, mocker):
29933098
@signal()
29943099
def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int = 0):

web/server/console.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def update_snapshot_evaluation_progress(
142142
num_audits_passed: int,
143143
num_audits_failed: int,
144144
audit_only: bool = False,
145-
auto_restatement_trigger: t.Optional[SnapshotId] = None,
145+
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
146146
) -> None:
147147
if audit_only:
148148
return

0 commit comments

Comments
 (0)