33from dataclasses import dataclass
44from sqlmesh .core import constants as c
55from sqlmesh .core .environment import EnvironmentStatements , EnvironmentNamingInfo , Environment
6- from sqlmesh .core .plan .common import should_force_rebuild
6+ from sqlmesh .core .plan .common import (
7+ should_force_rebuild ,
8+ identify_restatement_intervals_across_snapshot_versions ,
9+ SnapshotIntervalClearRequest ,
10+ )
711from sqlmesh .core .plan .definition import EvaluatablePlan
812from sqlmesh .core .state_sync import StateReader
913from sqlmesh .core .scheduler import merged_missing_intervals , SnapshotToIntervals
1216 Snapshot ,
1317 SnapshotTableInfo ,
1418 SnapshotId ,
15- Interval ,
1619)
1720
1821
@@ -98,16 +101,22 @@ class AuditOnlyRunStage:
98101
99102@dataclass
100103class RestatementStage :
101- """Restate intervals for given snapshots.
104+ """Clear intervals from state for given snapshots.
102105
103106 Args:
104- snapshot_intervals: Intervals to restate.
107+ snapshot_intervals_to_clear: Intervals to clear from state for snapshots in dev environments, keyed by snapshot name
108+ deployability_index: Deployability of the snapshots in :snapshot_intervals_to_clear, used for
109+ expaining exactly which physical tables are affected. Note that these snapshots are from
110+ other environments that are not present in this plan, so we cant use the plan DeployabilityIndex
105111 all_snapshots: All snapshots in the plan by name.
106112 """
107113
108- snapshot_intervals : t .Dict [SnapshotTableInfo , Interval ]
109114 all_snapshots : t .Dict [str , Snapshot ]
110115
116+ # Only used for --explain so may not be populated
117+ snapshot_intervals_to_clear : t .Optional [t .Dict [str , SnapshotIntervalClearRequest ]]
118+ deployability_index : t .Optional [DeployabilityIndex ]
119+
111120
112121@dataclass
113122class BackfillStage :
@@ -219,15 +228,17 @@ class PlanStagesBuilder:
219228 Args:
220229 state_reader: The state reader to use to read the snapshots and environment.
221230 default_catalog: The default catalog to use for the snapshots.
231+ explain: Whether the stages are bing built for PlanExplainer or for the PlanExecutor.
232+ This allows the ability to conditionally attach additional metadata to each stage that can be helpful for a user,
233+ but not necessarily required in order to execute the plan.
222234 """
223235
224236 def __init__ (
225- self ,
226- state_reader : StateReader ,
227- default_catalog : t .Optional [str ],
237+ self , state_reader : StateReader , default_catalog : t .Optional [str ], explain : bool = False
228238 ):
229239 self .state_reader = state_reader
230240 self .default_catalog = default_catalog
241+ self .explain = explain
231242
232243 def build (self , plan : EvaluatablePlan ) -> t .List [PlanStage ]:
233244 """Builds the plan stages for the given plan.
@@ -321,10 +332,6 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
321332 if audit_only_snapshots :
322333 stages .append (AuditOnlyRunStage (snapshots = list (audit_only_snapshots .values ())))
323334
324- restatement_stage = self ._get_restatement_stage (plan , snapshots_by_name )
325- if restatement_stage :
326- stages .append (restatement_stage )
327-
328335 if missing_intervals_before_promote :
329336 stages .append (
330337 BackfillStage (
@@ -349,6 +356,15 @@ def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]:
349356 )
350357 )
351358
359+ # note: "restatement stage" (which is clearing intervals in state - not actually performing the restatements, that's the backfill stage)
360+ # needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them.
361+ # in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data.
362+ # we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces
363+ # data for existing intervals and does not produce new ones
364+ restatement_stage = self ._get_restatement_stage (plan , snapshots_by_name )
365+ if restatement_stage :
366+ stages .append (restatement_stage )
367+
352368 stages .append (
353369 EnvironmentRecordUpdateStage (
354370 no_gaps_snapshot_names = {s .name for s in before_promote_snapshots }
@@ -443,15 +459,50 @@ def _get_after_all_stage(
443459 def _get_restatement_stage (
444460 self , plan : EvaluatablePlan , snapshots_by_name : t .Dict [str , Snapshot ]
445461 ) -> t .Optional [RestatementStage ]:
446- snapshot_intervals_to_restate = {}
447- for name , interval in plan .restatements .items ():
448- restated_snapshot = snapshots_by_name [name ]
449- restated_snapshot .remove_interval (interval )
450- snapshot_intervals_to_restate [restated_snapshot .table_info ] = interval
451- if not snapshot_intervals_to_restate or plan .is_dev :
462+ if not plan .restatements or plan .is_dev :
463+ # The RestatementStage to clear intervals from state across all environments is not needed for plans against dev, only prod
452464 return None
465+
466+ snapshot_intervals_to_clear : t .Optional [t .Dict [str , SnapshotIntervalClearRequest ]] = None
467+ deployability_index : t .Optional [DeployabilityIndex ] = None
468+
469+ if self .explain :
470+ # This is re-calculated in the PlanEvaluator because there can be a time lag between:
471+ # - a plan being generated and a plan being executed
472+ # - restatement backfill starting and restatement backfill completing
473+ # During this time, someone may create a new dev environment based on partially restated data.
474+ # So we look up the intervals to clear in other environments at the last minute after restatement has occurred.
475+ #
476+ # However, if the user wants to know what's going on (via --explain), we look them up here as well so we
477+ # can output them to the console as part of the `--explain` output
478+ intervals_to_clear = identify_restatement_intervals_across_snapshot_versions (
479+ state_reader = self .state_reader ,
480+ prod_restatements = plan .restatements ,
481+ disable_restatement_models = plan .disabled_restatement_models ,
482+ loaded_snapshots = {s .snapshot_id : s for s in snapshots_by_name .values ()},
483+ )
484+
485+ if intervals_to_clear :
486+ snapshot_intervals_to_clear = {
487+ s_id .name : r for s_id , r in intervals_to_clear .items ()
488+ }
489+
490+ # creating a deployability index over the "snapshot intervals to clear"
491+ # allows us to print the physical names of the tables affected in the console output
492+ # note that we can't use the DeployabilityIndex on the plan because it only includes
493+ # snapshots for the current environment, not across all environments
494+ deployability_index = DeployabilityIndex .create (
495+ snapshots = self .state_reader .get_snapshots (
496+ [s .snapshot_id for s in snapshot_intervals_to_clear .values ()]
497+ ),
498+ start = plan .start ,
499+ start_override_per_model = plan .start_override_per_model ,
500+ )
501+
453502 return RestatementStage (
454- snapshot_intervals = snapshot_intervals_to_restate , all_snapshots = snapshots_by_name
503+ snapshot_intervals_to_clear = snapshot_intervals_to_clear ,
504+ deployability_index = deployability_index ,
505+ all_snapshots = snapshots_by_name ,
455506 )
456507
457508 def _get_physical_layer_update_stage (
@@ -678,5 +729,6 @@ def build_plan_stages(
678729 plan : EvaluatablePlan ,
679730 state_reader : StateReader ,
680731 default_catalog : t .Optional [str ],
732+ explain : bool = False ,
681733) -> t .List [PlanStage ]:
682- return PlanStagesBuilder (state_reader , default_catalog ).build (plan )
734+ return PlanStagesBuilder (state_reader , default_catalog , explain = explain ).build (plan )
0 commit comments