Skip to content

Commit a3c53ea

Browse files
committed
Add directly modified and restatement triggers
1 parent 0119eef commit a3c53ea

File tree

7 files changed

+137
-25
lines changed

7 files changed

+137
-25
lines changed

sqlmesh/core/console.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3830,12 +3830,10 @@ def update_snapshot_evaluation_progress(
38303830
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
38313831
if snapshot_evaluation_triggers.select_snapshot_triggers:
38323832
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833-
3834-
if snapshot_evaluation_triggers:
3835-
if snapshot_evaluation_triggers.auto_restatement_triggers:
3836-
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3837-
if snapshot_evaluation_triggers.select_snapshot_triggers:
3838-
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833+
if snapshot_evaluation_triggers.directly_modified_triggers:
3834+
message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}"
3835+
if snapshot_evaluation_triggers.restatement_triggers:
3836+
message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}"
38393837

38403838
if audit_only:
38413839
message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2298,7 +2298,7 @@ def check_intervals(
22982298
if select_models:
22992299
selected, _ = self._select_models_for_run(select_models, True, snapshots.values())
23002300
else:
2301-
selected = t.cast(t.Set[str], snapshots.keys())
2301+
selected = set(snapshots.keys())
23022302

23032303
results = {}
23042304
execution_context = self.execution_context(snapshots=snapshots)

sqlmesh/core/plan/builder.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def build(self) -> Plan:
289289
else DeployabilityIndex.all_deployable()
290290
)
291291

292-
restatements = self._build_restatements(
292+
restatements, restatement_triggers = self._build_restatements(
293293
dag,
294294
earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time),
295295
)
@@ -326,6 +326,7 @@ def build(self) -> Plan:
326326
indirectly_modified=indirectly_modified,
327327
deployability_index=deployability_index,
328328
restatements=restatements,
329+
restatement_triggers=restatement_triggers,
329330
start_override_per_model=self._start_override_per_model,
330331
end_override_per_model=end_override_per_model,
331332
selected_models_to_backfill=self._backfill_models,
@@ -347,14 +348,14 @@ def _build_dag(self) -> DAG[SnapshotId]:
347348

348349
def _build_restatements(
349350
self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike
350-
) -> t.Dict[SnapshotId, Interval]:
351+
) -> t.Tuple[t.Dict[SnapshotId, Interval], t.Dict[SnapshotId, t.List[SnapshotId]]]:
351352
restate_models = self._restate_models
352353
if restate_models == set():
353354
# This is a warning but we print this as error since the Console is lacking API for warnings.
354355
self._console.log_error(
355356
"Provided restated models do not match any models. No models will be included in plan."
356357
)
357-
return {}
358+
return {}, {}
358359

359360
restatements: t.Dict[SnapshotId, Interval] = {}
360361
forward_only_preview_needed = self._forward_only_preview_needed
@@ -378,7 +379,7 @@ def _build_restatements(
378379
is_preview = True
379380

380381
if not restate_models:
381-
return {}
382+
return {}, {}
382383

383384
start = self._start or earliest_interval_start
384385
end = self._end or now()
@@ -388,6 +389,7 @@ def _build_restatements(
388389
if model_fqn not in self._model_fqn_to_snapshot:
389390
raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.")
390391

392+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
391393
# Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands its
392394
# restatement range that its downstream dependencies all expand their restatement ranges as well.
393395
for s_id in dag:
@@ -423,6 +425,13 @@ def _build_restatements(
423425
logger.info("Skipping restatement for model '%s'", snapshot.name)
424426
continue
425427

428+
if snapshot.name in restate_models:
429+
restatement_triggers[s_id] = [s_id]
430+
if restating_parents:
431+
restatement_triggers[s_id] = restatement_triggers.get(s_id, []) + [
432+
s.snapshot_id for s in restating_parents
433+
]
434+
426435
possible_intervals = {
427436
restatements[p.snapshot_id] for p in restating_parents if p.is_incremental
428437
}
@@ -451,7 +460,7 @@ def _build_restatements(
451460

452461
restatements[s_id] = (snapshot_start, snapshot_end)
453462

454-
return restatements
463+
return restatements, restatement_triggers
455464

456465
def _build_directly_and_indirectly_modified(
457466
self, dag: DAG[SnapshotId]

sqlmesh/core/plan/definition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Plan(PydanticModel, frozen=True):
5757

5858
deployability_index: DeployabilityIndex
5959
restatements: t.Dict[SnapshotId, Interval]
60+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
6061
start_override_per_model: t.Optional[t.Dict[str, datetime]]
6162
end_override_per_model: t.Optional[t.Dict[str, datetime]]
6263

@@ -254,6 +255,7 @@ def to_evaluatable(self) -> EvaluatablePlan:
254255
skip_backfill=self.skip_backfill,
255256
empty_backfill=self.empty_backfill,
256257
restatements={s.name: i for s, i in self.restatements.items()},
258+
restatement_triggers=self.restatement_triggers,
257259
is_dev=self.is_dev,
258260
allow_destructive_models=self.allow_destructive_models,
259261
forward_only=self.forward_only,
@@ -295,6 +297,7 @@ class EvaluatablePlan(PydanticModel):
295297
skip_backfill: bool
296298
empty_backfill: bool
297299
restatements: t.Dict[str, Interval]
300+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
298301
is_dev: bool
299302
allow_destructive_models: t.Set[str]
300303
forward_only: bool

sqlmesh/core/plan/evaluator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
SnapshotCreationFailedError,
3838
SnapshotNameVersion,
3939
)
40+
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers
4041
from sqlmesh.utils import to_snake_case
4142
from sqlmesh.core.state_sync import StateSync
4243
from sqlmesh.utils import CorrelationId
@@ -234,6 +235,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234235
self.console.log_success("SKIP: No model batches to execute")
235236
return
236237

238+
directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
239+
for parent, children in plan.indirectly_modified_snapshots.items():
240+
parent_id = stage.all_snapshots[parent].snapshot_id
241+
directly_modified_triggers[parent_id] = directly_modified_triggers.get(
242+
parent_id, []
243+
) + [parent_id]
244+
for child in children:
245+
directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [
246+
parent_id
247+
]
248+
directly_modified_triggers = {
249+
k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items()
250+
}
251+
snapshot_evaluation_triggers = {
252+
s_id: SnapshotEvaluationTriggers(
253+
directly_modified_triggers=directly_modified_triggers.get(s_id, []),
254+
restatement_triggers=plan.restatement_triggers.get(s_id, []),
255+
)
256+
for s_id in [s.snapshot_id for s in stage.all_snapshots.values()]
257+
}
258+
237259
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
238260
# Convert model name restatements to snapshot ID restatements
239261
restatements_by_snapshot_id = {
@@ -249,6 +271,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
249271
start=plan.start,
250272
end=plan.end,
251273
restatements=restatements_by_snapshot_id,
274+
snapshot_evaluation_triggers=snapshot_evaluation_triggers,
252275
)
253276
if errors:
254277
raise PlanError("Plan application failed.")

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ class SnapshotEvaluationTriggers(PydanticModel):
331331
cron_ready: t.Optional[bool] = None
332332
auto_restatement_triggers: t.List[SnapshotId] = []
333333
select_snapshot_triggers: t.List[SnapshotId] = []
334+
directly_modified_triggers: t.List[SnapshotId] = []
335+
restatement_triggers: t.List[SnapshotId] = []
334336

335337

336338
class SnapshotInfoMixin(ModelKindMixin):

tests/core/test_integration.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
from sqlmesh import CustomMaterialization
29+
import sqlmesh
2930
from sqlmesh.cli.project_init import init_example_project
3031
from sqlmesh.core import constants as c
3132
from sqlmesh.core import dialect as d
@@ -1805,26 +1806,97 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18051806
context, plan = init_and_plan_context("examples/sushi")
18061807
context.apply(plan)
18071808

1809+
# modify 3 models
1810+
# - 2 breaking changes for testing plan directly modified triggers
1811+
# - 1 adding an auto-restatement for subsequent `run` test
1812+
marketing = context.get_model("sushi.marketing")
1813+
marketing_kwargs = {
1814+
**marketing.dict(),
1815+
"query": d.parse_one(
1816+
f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1817+
),
1818+
}
1819+
context.upsert_model(SqlModel.parse_obj(marketing_kwargs))
1820+
1821+
customers = context.get_model("sushi.customers")
1822+
customers_kwargs = {
1823+
**customers.dict(),
1824+
"query": d.parse_one(
1825+
f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1826+
),
1827+
}
1828+
context.upsert_model(SqlModel.parse_obj(customers_kwargs))
1829+
18081830
# add auto restatement to orders
1809-
model = context.get_model("sushi.orders")
1810-
kind = {
1811-
**model.kind.dict(),
1831+
orders = context.get_model("sushi.orders")
1832+
orders_kind = {
1833+
**orders.kind.dict(),
18121834
"auto_restatement_cron": "@hourly",
18131835
}
1814-
kwargs = {
1815-
**model.dict(),
1816-
"kind": kind,
1836+
orders_kwargs = {
1837+
**orders.dict(),
1838+
"kind": orders_kind,
18171839
}
1818-
context.upsert_model(PythonModel.parse_obj(kwargs))
1819-
plan = context.plan_builder(skip_tests=True).build()
1820-
context.apply(plan)
1840+
context.upsert_model(PythonModel.parse_obj(orders_kwargs))
18211841

1822-
# Mock run_merged_intervals to capture triggers arg
1823-
scheduler = context.scheduler()
1824-
run_merged_intervals_mock = mocker.patch.object(
1825-
scheduler, "run_merged_intervals", return_value=([], [])
1842+
spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals")
1843+
1844+
context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full())
1845+
1846+
# PLAN: directly modified triggers
1847+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1848+
actual_triggers_name = {
1849+
k.name: sorted([s.name for s in v.directly_modified_triggers])
1850+
for k, v in actual_triggers.items()
1851+
if v.directly_modified_triggers
1852+
}
1853+
marketing_name = '"memory"."sushi"."marketing"'
1854+
customers_name = '"memory"."sushi"."customers"'
1855+
marketing_customers_names = sorted([marketing_name, customers_name])
1856+
children_names = [
1857+
f'"memory"."sushi"."{model}"'
1858+
for model in {
1859+
"waiter_as_customer_by_day",
1860+
"active_customers",
1861+
"count_customers_active",
1862+
"count_customers_inactive",
1863+
}
1864+
]
1865+
assert actual_triggers_name == {
1866+
marketing_name: [marketing_name],
1867+
customers_name: [customers_name],
1868+
**{k: marketing_customers_names for k in children_names},
1869+
}
1870+
1871+
# PLAN: restatement triggers
1872+
spy.reset_mock()
1873+
context.plan(
1874+
restate_models=[
1875+
'"memory"."sushi"."marketing"',
1876+
'"memory"."sushi"."order_items"',
1877+
'"memory"."sushi"."waiter_revenue_by_day"',
1878+
],
1879+
auto_apply=True,
1880+
no_prompts=True,
18261881
)
18271882

1883+
order_items_name = '"memory"."sushi"."order_items"'
1884+
waiter_revenue_by_day_name = '"memory"."sushi"."waiter_revenue_by_day"'
1885+
actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"]
1886+
actual_triggers_name = {
1887+
k.name: sorted([s.name for s in v.restatement_triggers])
1888+
for k, v in actual_triggers.items()
1889+
if v.restatement_triggers
1890+
}
1891+
assert actual_triggers_name == {
1892+
waiter_revenue_by_day_name: [waiter_revenue_by_day_name, order_items_name],
1893+
order_items_name: [order_items_name],
1894+
'"memory"."sushi"."top_waiters"': [waiter_revenue_by_day_name],
1895+
'"memory"."sushi"."customer_revenue_by_day"': [order_items_name],
1896+
'"memory"."sushi"."customer_revenue_lifetime"': [order_items_name],
1897+
}
1898+
1899+
# RUN: select and auto-restatement triggers
18281900
# User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream
18291901
selected_models = {"top_waiters", "waiter_revenue_by_day"}
18301902
selected_models_auto_upstream = {"order_items", "orders", "items"}
@@ -1835,6 +1907,11 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18351907
f'"memory"."sushi"."{model}"' for model in selected_models
18361908
}
18371909

1910+
scheduler = context.scheduler()
1911+
run_merged_intervals_mock = mocker.patch.object(
1912+
scheduler, "run_merged_intervals", return_value=([], [])
1913+
)
1914+
18381915
with time_machine.travel("2023-01-09 00:00:01 UTC"):
18391916
scheduler.run(
18401917
environment=c.PROD,

0 commit comments

Comments
 (0)