Skip to content
12 changes: 10 additions & 2 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def update_snapshot_evaluation_progress(
num_audits_passed: int,
num_audits_failed: int,
audit_only: bool = False,
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
) -> None:
"""Updates the snapshot evaluation progress."""

Expand Down Expand Up @@ -575,6 +576,7 @@ def update_snapshot_evaluation_progress(
num_audits_passed: int,
num_audits_failed: int,
audit_only: bool = False,
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
) -> None:
pass

Expand Down Expand Up @@ -1056,6 +1058,7 @@ def update_snapshot_evaluation_progress(
num_audits_passed: int,
num_audits_failed: int,
audit_only: bool = False,
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
) -> None:
"""Update the snapshot evaluation progress."""
if (
Expand Down Expand Up @@ -3639,6 +3642,7 @@ def update_snapshot_evaluation_progress(
num_audits_passed: int,
num_audits_failed: int,
audit_only: bool = False,
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
) -> None:
view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id]

Expand Down Expand Up @@ -3808,11 +3812,15 @@ def update_snapshot_evaluation_progress(
num_audits_passed: int,
num_audits_failed: int,
audit_only: bool = False,
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
) -> None:
message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

if auto_restatement_triggers:
message += f" | auto_restatement_triggers=[{', '.join(trigger.name for trigger in auto_restatement_triggers)}]"

if audit_only:
message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
message = f"Audited {snapshot.name} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

self._write(message)

Expand Down
18 changes: 17 additions & 1 deletion sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ def run_merged_intervals(
selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None,
run_environment_statements: bool = False,
audit_only: bool = False,
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
"""Runs precomputed batches of missing intervals.

Expand Down Expand Up @@ -531,6 +532,9 @@ def run_node(node: SchedulingUnit) -> None:
evaluation_duration_ms,
num_audits - num_audits_failed,
num_audits_failed,
auto_restatement_triggers=auto_restatement_triggers.get(
snapshot.snapshot_id
),
)
elif isinstance(node, CreateNode):
self.snapshot_evaluator.create_snapshot(
Expand Down Expand Up @@ -736,8 +740,11 @@ def _run_or_audit(
for s_id, interval in (remove_intervals or {}).items():
self.snapshots[s_id].remove_interval(interval)

all_auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
if auto_restatement_enabled:
auto_restated_intervals = apply_auto_restatements(self.snapshots, execution_time)
auto_restated_intervals, all_auto_restatement_triggers = apply_auto_restatements(
self.snapshots, execution_time
)
self.state_sync.add_snapshots_intervals(auto_restated_intervals)
self.state_sync.update_auto_restatements(
{s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()}
Expand All @@ -758,6 +765,14 @@ def _run_or_audit(
if not merged_intervals:
return CompletionStatus.NOTHING_TO_DO

auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
if all_auto_restatement_triggers:
merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals}
auto_restatement_triggers = {
s_id: all_auto_restatement_triggers.get(s_id, [])
for s_id in merged_intervals_snapshots
}

errors, _ = self.run_merged_intervals(
merged_intervals=merged_intervals,
deployability_index=deployability_index,
Expand All @@ -768,6 +783,7 @@ def _run_or_audit(
end=end,
run_environment_statements=run_environment_statements,
audit_only=audit_only,
auto_restatement_triggers=auto_restatement_triggers,
)

return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
Expand Down
48 changes: 32 additions & 16 deletions sqlmesh/core/snapshot/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind
from sqlmesh.core.model.definition import _Model
from sqlmesh.core.node import IntervalUnit, NodeType
from sqlmesh.utils import sanitize_name
from sqlmesh.utils import sanitize_name, unique
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import (
TimeLike,
Expand Down Expand Up @@ -2180,7 +2180,7 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]:

def apply_auto_restatements(
snapshots: t.Dict[SnapshotId, Snapshot], execution_time: TimeLike
) -> t.List[SnapshotIntervals]:
) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, t.List[SnapshotId]]]:
"""Applies auto restatements to the snapshots.

This operation results in the removal of intervals for snapshots that are ready to be restated based
Expand All @@ -2195,6 +2195,7 @@ def apply_auto_restatements(
A list of SnapshotIntervals with **new** intervals that need to be restated.
"""
dag = snapshots_to_dag(snapshots.values())
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
auto_restated_intervals_per_snapshot: t.Dict[SnapshotId, Interval] = {}
for s_id in dag:
if s_id not in snapshots:
Expand All @@ -2209,6 +2210,7 @@ def apply_auto_restatements(
for parent_s_id in snapshot.parents
if parent_s_id in auto_restated_intervals_per_snapshot
]
upstream_triggers = []
if next_auto_restated_interval:
logger.info(
"Calculated the next auto restated interval (%s, %s) for snapshot %s",
Expand All @@ -2218,6 +2220,18 @@ def apply_auto_restatements(
)
auto_restated_intervals.append(next_auto_restated_interval)

# auto-restated snapshot is its own trigger
upstream_triggers = [s_id]
else:
# inherit each parent's auto-restatement triggers (if any)
for parent_s_id in snapshot.parents:
if parent_s_id in auto_restatement_triggers:
upstream_triggers.extend(auto_restatement_triggers[parent_s_id])

# remove duplicate triggers, retaining order and keeping first seen of duplicates
if upstream_triggers:
auto_restatement_triggers[s_id] = unique(upstream_triggers)

if auto_restated_intervals:
auto_restated_interval_start = sys.maxsize
auto_restated_interval_end = -sys.maxsize
Expand Down Expand Up @@ -2247,20 +2261,22 @@ def apply_auto_restatements(

snapshot.apply_pending_restatement_intervals()
snapshot.update_next_auto_restatement_ts(execution_time)

return [
SnapshotIntervals(
name=snapshots[s_id].name,
identifier=None,
version=snapshots[s_id].version,
dev_version=None,
intervals=[],
dev_intervals=[],
pending_restatement_intervals=[interval],
)
for s_id, interval in auto_restated_intervals_per_snapshot.items()
if s_id in snapshots
]
return (
[
SnapshotIntervals(
name=snapshots[s_id].name,
identifier=None,
version=snapshots[s_id].version,
dev_version=None,
intervals=[],
dev_intervals=[],
pending_restatement_intervals=[interval],
)
for s_id, interval in auto_restated_intervals_per_snapshot.items()
if s_id in snapshots
],
auto_restatement_triggers,
)


def parent_snapshots_by_name(
Expand Down
Loading