Skip to content

Commit f5b3abe

Browse files
committed
Add directly modified and restatement triggers
1 parent d5f9d24 commit f5b3abe

File tree

7 files changed

+137
-25
lines changed

7 files changed

+137
-25
lines changed

sqlmesh/core/console.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3830,12 +3830,10 @@ def update_snapshot_evaluation_progress(
38303830
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
38313831
if snapshot_evaluation_triggers.select_snapshot_triggers:
38323832
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833-
3834-
if snapshot_evaluation_triggers:
3835-
if snapshot_evaluation_triggers.auto_restatement_triggers:
3836-
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3837-
if snapshot_evaluation_triggers.select_snapshot_triggers:
3838-
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833+
if snapshot_evaluation_triggers.directly_modified_triggers:
3834+
message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}"
3835+
if snapshot_evaluation_triggers.restatement_triggers:
3836+
message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}"
38393837

38403838
if audit_only:
38413839
message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2298,7 +2298,7 @@ def check_intervals(
22982298
if select_models:
22992299
selected, _ = self._select_models_for_run(select_models, True, snapshots.values())
23002300
else:
2301-
selected = t.cast(t.Set[str], snapshots.keys())
2301+
selected = set(snapshots.keys())
23022302

23032303
results = {}
23042304
execution_context = self.execution_context(snapshots=snapshots)

sqlmesh/core/plan/builder.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def build(self) -> Plan:
293293
else DeployabilityIndex.all_deployable()
294294
)
295295

296-
restatements = self._build_restatements(
296+
restatements, restatement_triggers = self._build_restatements(
297297
dag,
298298
earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time),
299299
)
@@ -330,6 +330,7 @@ def build(self) -> Plan:
330330
indirectly_modified=indirectly_modified,
331331
deployability_index=deployability_index,
332332
restatements=restatements,
333+
restatement_triggers=restatement_triggers,
333334
start_override_per_model=self._start_override_per_model,
334335
end_override_per_model=end_override_per_model,
335336
selected_models_to_backfill=self._backfill_models,
@@ -351,14 +352,14 @@ def _build_dag(self) -> DAG[SnapshotId]:
351352

352353
def _build_restatements(
353354
self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike
354-
) -> t.Dict[SnapshotId, Interval]:
355+
) -> t.Tuple[t.Dict[SnapshotId, Interval], t.Dict[SnapshotId, t.List[SnapshotId]]]:
355356
restate_models = self._restate_models
356357
if restate_models == set():
357358
# This is a warning, but we print it as an error since the Console lacks an API for warnings.
358359
self._console.log_error(
359360
"Provided restated models do not match any models. No models will be included in plan."
360361
)
361-
return {}
362+
return {}, {}
362363

363364
restatements: t.Dict[SnapshotId, Interval] = {}
364365
forward_only_preview_needed = self._forward_only_preview_needed
@@ -380,7 +381,7 @@ def _build_restatements(
380381
is_preview = True
381382

382383
if not restate_models:
383-
return {}
384+
return {}, {}
384385

385386
start = self._start or earliest_interval_start
386387
end = self._end or now()
@@ -390,6 +391,7 @@ def _build_restatements(
390391
if model_fqn not in self._model_fqn_to_snapshot:
391392
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")
392393

394+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
393395
# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
394396
# restatement range that its downstream dependencies all expand their restatement ranges as well.
395397
for s_id in dag:
@@ -422,6 +424,13 @@ def _build_restatements(
422424
logger.info("Skipping restatement for model '%s'", snapshot.name)
423425
continue
424426

427+
if snapshot.name in restate_models:
428+
restatement_triggers[s_id] = [s_id]
429+
if restating_parents:
430+
restatement_triggers[s_id] = restatement_triggers.get(s_id, []) + [
431+
s.snapshot_id for s in restating_parents
432+
]
433+
425434
possible_intervals = {
426435
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
427436
}
@@ -450,7 +459,7 @@ def _build_restatements(
450459

451460
restatements[s_id] = (snapshot_start, snapshot_end)
452461

453-
return restatements
462+
return restatements, restatement_triggers
454463

455464
def _build_directly_and_indirectly_modified(
456465
self, dag: DAG[SnapshotId]

sqlmesh/core/plan/definition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Plan(PydanticModel, frozen=True):
5757

5858
deployability_index: DeployabilityIndex
5959
restatements: t.Dict[SnapshotId, Interval]
60+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
6061
start_override_per_model: t.Optional[t.Dict[str, datetime]]
6162
end_override_per_model: t.Optional[t.Dict[str, datetime]]
6263

@@ -254,6 +255,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
254255
skip_backfill=self.skip_backfill,
255256
empty_backfill=self.empty_backfill,
256257
restatements={s.name: i for s, i in self.restatements.items()},
258+
restatement_triggers=self.restatement_triggers,
257259
is_dev=self.is_dev,
258260
allow_destructive_models=self.allow_destructive_models,
259261
forward_only=self.forward_only,
@@ -295,6 +297,7 @@ class EvaluatablePlan(PydanticModel):
295297
skip_backfill: bool
296298
empty_backfill: bool
297299
restatements: t.Dict[str, Interval]
300+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
298301
is_dev: bool
299302
allow_destructive_models: t.Set[str]
300303
forward_only: bool

sqlmesh/core/plan/evaluator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
SnapshotCreationFailedError,
3838
SnapshotNameVersion,
3939
)
40+
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers
4041
from sqlmesh.utils import to_snake_case
4142
from sqlmesh.core.state_sync import StateSync
4243
from sqlmesh.utils import CorrelationId
@@ -234,6 +235,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234235
self.console.log_success("SKIP: No model batches to execute")
235236
return
236237

238+
directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
239+
for parent, children in plan.indirectly_modified_snapshots.items():
240+
parent_id = stage.all_snapshots[parent].snapshot_id
241+
directly_modified_triggers[parent_id] = directly_modified_triggers.get(
242+
parent_id, []
243+
) + [parent_id]
244+
for child in children:
245+
directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [
246+
parent_id
247+
]
248+
directly_modified_triggers = {
249+
k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items()
250+
}
251+
snapshot_evaluation_triggers = {
252+
s_id: SnapshotEvaluationTriggers(
253+
directly_modified_triggers=directly_modified_triggers.get(s_id, []),
254+
restatement_triggers=plan.restatement_triggers.get(s_id, []),
255+
)
256+
for s_id in [s.snapshot_id for s in stage.all_snapshots.values()]
257+
}
258+
237259
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
238260
# Convert model name restatements to snapshot ID restatements
239261
restatements_by_snapshot_id = {
@@ -249,6 +271,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
249271
start=plan.start,
250272
end=plan.end,
251273
restatements=restatements_by_snapshot_id,
274+
snapshot_evaluation_triggers=snapshot_evaluation_triggers,
252275
)
253276
if errors:
254277
raise PlanError("Plan application failed.")

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,8 @@ class SnapshotEvaluationTriggers(PydanticModel):
330330
cron_ready: t.Optional[bool] = None
331331
auto_restatement_triggers: t.List[SnapshotId] = []
332332
select_snapshot_triggers: t.List[SnapshotId] = []
333+
directly_modified_triggers: t.List[SnapshotId] = []
334+
restatement_triggers: t.List[SnapshotId] = []
333335

334336

335337
class SnapshotInfoMixin(ModelKindMixin):

tests/core/test_integration.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
from sqlmesh import CustomMaterialization
29+
import sqlmesh
2930
from sqlmesh.cli.project_init import init_example_project
3031
from sqlmesh.core import constants as c
3132
from sqlmesh.core import dialect as d
@@ -1813,26 +1814,97 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18131814
context, plan = init_and_plan_context("examples/sushi")
18141815
context.apply(plan)
18151816

1817+
# modify 3 models
1818+
# - 2 breaking changes for testing plan directly modified triggers
1819+
# - 1 adding an auto-restatement for subsequent `run` test
1820+
marketing = context.get_model("sushi.marketing")
1821+
marketing_kwargs = {
1822+
**marketing.dict(),
1823+
"query": d.parse_one(
1824+
f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1825+
),
1826+
}
1827+
context.upsert_model(SqlModel.parse_obj(marketing_kwargs))
1828+
1829+
customers = context.get_model("sushi.customers")
1830+
customers_kwargs = {
1831+
**customers.dict(),
1832+
"query": d.parse_one(
1833+
f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1834+
),
1835+
}
1836+
context.upsert_model(SqlModel.parse_obj(customers_kwargs))
1837+
18161838
# add auto restatement to orders
1817-
model = context.get_model("sushi.orders")
1818-
kind = {
1819-
**model.kind.dict(),
1839+
orders = context.get_model("sushi.orders")
1840+
orders_kind = {
1841+
**orders.kind.dict(),
18201842
"auto_restatement_cron": "@hourly",
18211843
}
1822-
kwargs = {
1823-
**model.dict(),
1824-
"kind": kind,
1844+
orders_kwargs = {
1845+
**orders.dict(),
1846+
"kind": orders_kind,
18251847
}
1826-
context.upsert_model(PythonModel.parse_obj(kwargs))
1827-
plan = context.plan_builder(skip_tests=True).build()
1828-
context.apply(plan)
1848+
context.upsert_model(PythonModel.parse_obj(orders_kwargs))
18291849

1830-
# Mock run_merged_intervals to capture triggers arg
1831-
scheduler = context.scheduler()
1832-
run_merged_intervals_mock = mocker.patch.object(
1833-
scheduler, "run_merged_intervals", return_value=([], [])
1850+
spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals")
1851+
1852+
context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full())
1853+
1854+
# PLAN: directly modified triggers
1855+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1856+
actual_triggers_name = {
1857+
k.name: sorted([s.name for s in v.directly_modified_triggers])
1858+
for k, v in actual_triggers.items()
1859+
if v.directly_modified_triggers
1860+
}
1861+
marketing_name = '"memory"."sushi"."marketing"'
1862+
customers_name = '"memory"."sushi"."customers"'
1863+
marketing_customers_names = sorted([marketing_name, customers_name])
1864+
children_names = [
1865+
f'"memory"."sushi"."{model}"'
1866+
for model in {
1867+
"waiter_as_customer_by_day",
1868+
"active_customers",
1869+
"count_customers_active",
1870+
"count_customers_inactive",
1871+
}
1872+
]
1873+
assert actual_triggers_name == {
1874+
marketing_name: [marketing_name],
1875+
customers_name: [customers_name],
1876+
**{k: marketing_customers_names for k in children_names},
1877+
}
1878+
1879+
# PLAN: restatement triggers
1880+
spy.reset_mock()
1881+
context.plan(
1882+
restate_models=[
1883+
'"memory"."sushi"."marketing"',
1884+
'"memory"."sushi"."order_items"',
1885+
'"memory"."sushi"."waiter_revenue_by_day"',
1886+
],
1887+
auto_apply=True,
1888+
no_prompts=True,
18341889
)
18351890

1891+
order_items_name = '"memory"."sushi"."order_items"'
1892+
waiter_revenue_by_day_name = '"memory"."sushi"."waiter_revenue_by_day"'
1893+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1894+
actual_triggers_name = {
1895+
k.name: sorted([s.name for s in v.restatement_triggers])
1896+
for k, v in actual_triggers.items()
1897+
if v.restatement_triggers
1898+
}
1899+
assert actual_triggers_name == {
1900+
waiter_revenue_by_day_name: [waiter_revenue_by_day_name, order_items_name],
1901+
order_items_name: [order_items_name],
1902+
'"memory"."sushi"."top_waiters"': [waiter_revenue_by_day_name],
1903+
'"memory"."sushi"."customer_revenue_by_day"': [order_items_name],
1904+
'"memory"."sushi"."customer_revenue_lifetime"': [order_items_name],
1905+
}
1906+
1907+
# RUN: select and auto-restatement triggers
18361908
# User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream
18371909
selected_models = {"top_waiters", "waiter_revenue_by_day"}
18381910
selected_models_auto_upstream = {"order_items", "orders", "items"}
@@ -1843,6 +1915,11 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18431915
f'"memory"."sushi"."{model}"' for model in selected_models
18441916
}
18451917

1918+
scheduler = context.scheduler()
1919+
run_merged_intervals_mock = mocker.patch.object(
1920+
scheduler, "run_merged_intervals", return_value=([], [])
1921+
)
1922+
18461923
with time_machine.travel("2023-01-09 00:00:01 UTC"):
18471924
scheduler.run(
18481925
environment=c.PROD,

0 commit comments

Comments
 (0)