Skip to content

Commit 4ff0abb

Browse files
committed
Feat: prevent other processes seeing missing intervals during restatement
1 parent 042e3a1 commit 4ff0abb

File tree

5 files changed

+682
-59
lines changed

5 files changed

+682
-59
lines changed

sqlmesh/core/plan/evaluator.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -284,31 +284,34 @@ def visit_audit_only_run_stage(
284284
def visit_restatement_stage(
285285
self, stage: stages.RestatementStage, plan: EvaluatablePlan
286286
) -> None:
287-
snapshot_intervals_to_restate = {(s, i) for s, i in stage.snapshot_intervals.items()}
288-
289-
# Restating intervals on prod plans should mean that the intervals are cleared across
290-
# all environments, not just the version currently in prod
287+
# Restating intervals on prod plans means that once the data for the intervals being restated has been refreshed
288+
# (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments.
289+
#
291290
# This ensures that work done in dev environments can still be promoted to prod
292-
# by forcing dev environments to re-run intervals that changed in prod
291+
# by forcing dev environments to re-run intervals that changed in prod (because after this stage runs they show as missing)
292+
#
293+
# It also means that any new dev environments created while this restatement plan was running also get the
294+
# correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan
295+
# was created, which could have been several hours ago if there was a lot of data to restate.
293296
#
294297
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
295-
snapshot_intervals_to_restate.update(
296-
{
297-
(s.table_info, s.interval)
298-
for s in identify_restatement_intervals_across_snapshot_versions(
299-
state_reader=self.state_sync,
300-
prod_restatements=plan.restatements,
301-
disable_restatement_models=plan.disabled_restatement_models,
302-
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
303-
current_ts=to_timestamp(plan.execution_time or now()),
304-
).values()
305-
}
306-
)
307-
308-
self.state_sync.remove_intervals(
309-
snapshot_intervals=list(snapshot_intervals_to_restate),
310-
remove_shared_versions=plan.is_prod,
311-
)
298+
snapshot_intervals_to_restate = {
299+
(s.table_info, s.interval)
300+
for s in identify_restatement_intervals_across_snapshot_versions(
301+
state_reader=self.state_sync,
302+
prod_restatements=plan.restatements,
303+
# TODO: we need to ensure that only dev environments are affected
304+
disable_restatement_models=plan.disabled_restatement_models,
305+
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
306+
current_ts=to_timestamp(plan.execution_time or now()),
307+
).values()
308+
}
309+
310+
if snapshot_intervals_to_restate:
311+
self.state_sync.remove_intervals(
312+
snapshot_intervals=list(snapshot_intervals_to_restate),
313+
remove_shared_versions=plan.is_prod,
314+
)
312315

313316
def visit_environment_record_update_stage(
314317
self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan

sqlmesh/core/plan/explainer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def evaluate(
4141
plan: EvaluatablePlan,
4242
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
4343
) -> None:
44-
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)
44+
plan_stages = stages.build_plan_stages(
45+
plan, self.state_reader, self.default_catalog, explain=True
46+
)
4547
explainer_console = _get_explainer_console(
4648
self.console, plan.environment, self.default_catalog
4749
)
@@ -148,9 +150,11 @@ def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree:
148150

149151
def visit_restatement_stage(self, stage: stages.RestatementStage) -> Tree:
150152
tree = Tree("[bold]Invalidate data intervals as part of restatement[/bold]")
151-
for snapshot_table_info, interval in stage.snapshot_intervals.items():
152-
display_name = self._display_name(snapshot_table_info)
153-
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
153+
if snapshot_intervals := stage.snapshot_intervals_to_clear:
154+
for clear_request in snapshot_intervals.values():
155+
display_name = self._display_name(clear_request.table_info)
156+
interval = clear_request.interval
157+
tree.add(f"{display_name} [{to_ts(interval[0])} - {to_ts(interval[1])}]")
154158
return tree
155159

156160
def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree:

sqlmesh/core/plan/stages.py

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from dataclasses import dataclass
44
from sqlmesh.core import constants as c
55
from sqlmesh.core.environment import EnvironmentStatements, EnvironmentNamingInfo, Environment
6-
from sqlmesh.core.plan.common import should_force_rebuild
6+
from sqlmesh.core.plan.common import (
7+
should_force_rebuild,
8+
identify_restatement_intervals_across_snapshot_versions,
9+
SnapshotIntervalClearRequest,
10+
)
711
from sqlmesh.core.plan.definition import EvaluatablePlan
812
from sqlmesh.core.state_sync import StateReader
913
from sqlmesh.core.scheduler import merged_missing_intervals, SnapshotToIntervals
@@ -12,7 +16,6 @@
1216
Snapshot,
1317
SnapshotTableInfo,
1418
SnapshotId,
15-
Interval,
1619
)
1720

1821

@@ -98,16 +101,22 @@ class AuditOnlyRunStage:
98101

99102
@dataclass
100103
class RestatementStage:
101-
"""Restate intervals for given snapshots.
104+
"""Clear intervals from state for given snapshots.
102105
103106
Args:
104-
snapshot_intervals: Intervals to restate.
107+
snapshot_intervals_to_clear: Intervals to clear from state for snapshots in dev environments, keyed by snapshot name
108+
deployability_index: Deployability of the snapshots in :snapshot_intervals_to_clear, used for
109+
expaining exactly which physical tables are affected. Note that these snapshots are from
110+
other environments that are not present in this plan, so we cant use the plan DeployabilityIndex
105111
all_snapshots: All snapshots in the plan by name.
106112
"""
107113

108-
snapshot_intervals: t.Dict[SnapshotTableInfo, Interval]
109114
all_snapshots: t.Dict[str, Snapshot]
110115

116+
# Only used for --explain so may not be populated
117+
snapshot_intervals_to_clear: t.Optional[t.Dict[str, SnapshotIntervalClearRequest]]
118+
deployability_index: t.Optional[DeployabilityIndex]
119+
111120

112121
@dataclass
113122
class BackfillStage:
@@ -219,15 +228,17 @@ class PlanStagesBuilder:
219228
Args:
220229
state_reader: The state reader to use to read the snapshots and environment.
221230
default_catalog: The default catalog to use for the snapshots.
231+
explain: Whether the stages are bing built for PlanExplainer or for the PlanExecutor.
232+
This allows the ability to conditionally attach additional metadata to each stage that can be helpful for a user,
233+
but not necessarily required in order to execute the plan.
222234
"""
223235

224236
def __init__(
225-
self,
226-
state_reader: StateReader,
227-
default_catalog: t.Optional[str],
237+
self, state_reader: StateReader, default_catalog: t.Optional[str], explain: bool = False
228238
):
229239
self.state_reader = state_reader
230240
self.default_catalog = default_catalog
241+
self.explain = explain
231242

232243
def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
233244
"""Builds the plan stages for the given plan.
@@ -321,10 +332,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
321332
if audit_only_snapshots:
322333
stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values())))
323334

324-
restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
325-
if restatement_stage:
326-
stages.append(restatement_stage)
327-
328335
if missing_intervals_before_promote:
329336
stages.append(
330337
BackfillStage(
@@ -349,6 +356,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
349356
)
350357
)
351358

359+
# note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage)
360+
# needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them.
361+
# in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data.
362+
# we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces
363+
# data for existing intervals and does not produce new ones
364+
restatement_stage = self._get_restatement_stage(plan, snapshots_by_name)
365+
if restatement_stage:
366+
stages.append(restatement_stage)
367+
352368
stages.append(
353369
EnvironmentRecordUpdateStage(
354370
no_gaps_snapshot_names={s.name for s in before_promote_snapshots}
@@ -443,15 +459,50 @@ def _get_after_all_stage(
443459
def _get_restatement_stage(
444460
self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot]
445461
) -> t.Optional[RestatementStage]:
446-
snapshot_intervals_to_restate = {}
447-
for name, interval in plan.restatements.items():
448-
restated_snapshot = snapshots_by_name[name]
449-
restated_snapshot.remove_interval(interval)
450-
snapshot_intervals_to_restate[restated_snapshot.table_info] = interval
451-
if not snapshot_intervals_to_restate or plan.is_dev:
462+
if not plan.restatements or plan.is_dev:
463+
# The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod
452464
return None
465+
466+
snapshot_intervals_to_clear: t.Optional[t.Dict[str, SnapshotIntervalClearRequest]] = None
467+
deployability_index: t.Optional[DeployabilityIndex] = None
468+
469+
if self.explain:
470+
# This is re-calculated in the PlanEvaluator because there can be a time lag between:
471+
# - a plan being generated and a plan being executed
472+
# - restatement backfill starting and restatement backfill completing
473+
# During this time, someone may create a new dev environment based on partially restated data.
474+
# So we look up the intervals to clear in other environments at the last minute after restatement has occurred.
475+
#
476+
# However, if the user wants to know what's going on (via --explain), we look them up here as well so we
477+
# can output them to the console as part of the `--explain` output
478+
intervals_to_clear = identify_restatement_intervals_across_snapshot_versions(
479+
state_reader=self.state_reader,
480+
prod_restatements=plan.restatements,
481+
disable_restatement_models=plan.disabled_restatement_models,
482+
loaded_snapshots={s.snapshot_id: s for s in snapshots_by_name.values()},
483+
)
484+
485+
if intervals_to_clear:
486+
snapshot_intervals_to_clear = {
487+
s_id.name: r for s_id, r in intervals_to_clear.items()
488+
}
489+
490+
# creating a deployability index over the "snapshot intervals to clear"
491+
# allows us to print the physical names of the tables affected in the console output
492+
# note that we can't use the DeployabilityIndex on the plan because it only includes
493+
# snapshots for the current environment, not across all environments
494+
deployability_index = DeployabilityIndex.create(
495+
snapshots=self.state_reader.get_snapshots(
496+
[s.snapshot_id for s in snapshot_intervals_to_clear.values()]
497+
),
498+
start=plan.start,
499+
start_override_per_model=plan.start_override_per_model,
500+
)
501+
453502
return RestatementStage(
454-
snapshot_intervals=snapshot_intervals_to_restate, all_snapshots=snapshots_by_name
503+
snapshot_intervals_to_clear=snapshot_intervals_to_clear,
504+
deployability_index=deployability_index,
505+
all_snapshots=snapshots_by_name,
455506
)
456507

457508
def _get_physical_layer_update_stage(
@@ -678,5 +729,6 @@ def build_plan_stages(
678729
plan: EvaluatablePlan,
679730
state_reader: StateReader,
680731
default_catalog: t.Optional[str],
732+
explain: bool = False,
681733
) -> t.List[PlanStage]:
682-
return PlanStagesBuilder(state_reader, default_catalog).build(plan)
734+
return PlanStagesBuilder(state_reader, default_catalog, explain=explain).build(plan)

0 commit comments

Comments
 (0)