Skip to content

Commit 06c11df

Browse files
committed
Fix: When restating prod, clear intervals across all related snapshots, not just promoted ones
1 parent 4c42b45 commit 06c11df

File tree

3 files changed

+376
-102
lines changed

3 files changed

+376
-102
lines changed

sqlmesh/core/plan/common.py

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
2+
import typing as t
3+
import logging
4+
from dataclasses import dataclass, field
25

3-
from sqlmesh.core.snapshot import Snapshot
6+
from sqlmesh.core.state_sync import StateReader
7+
from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo, SnapshotNameVersion
8+
from sqlmesh.core.snapshot.definition import Interval
9+
from sqlmesh.utils.dag import DAG
10+
from sqlmesh.utils.date import now_timestamp
11+
12+
logger = logging.getLogger(__name__)
413

514

615
def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool:
@@ -27,3 +36,182 @@ def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
2736
# If the partitioning hasn't changed, then we don't need to rebuild
2837
return False
2938
return True
39+
40+
41+
@dataclass
class SnapshotIntervalClearRequest:
    """Describes one interval that should be cleared for one snapshot.

    Instances are produced while working out which snapshots are affected by a
    prod restatement; each instance pairs the affected snapshot with the
    interval to clear and the environments the snapshot was found in.
    """

    # The affected snapshot.
    table_info: SnapshotTableInfo

    # The interval to clear for that snapshot.
    interval: Interval

    # Names of the environments in which this snapshot is currently promoted.
    # Note that this can be empty if the snapshot still exists because its TTL
    # has not expired but it is not part of any particular environment.
    environment_names: t.Set[str] = field(default_factory=set)

    @property
    def snapshot_id(self) -> SnapshotId:
        """Return the id of the affected snapshot, as exposed by ``table_info``."""
        return self.table_info.snapshot_id

    @property
    def sorted_environment_names(self) -> t.List[str]:
        """Return the environment names as a new list in ascending order."""
        # sorted() already returns a fresh list, so wrapping it in list() is redundant
        return sorted(self.environment_names)
61+
62+
63+
def _widest_interval(first: Interval, second: Interval) -> Interval:
    """Return the smallest interval that covers both *first* and *second*."""
    return (min(first[0], second[0]), max(first[1], second[1]))


def identify_restatement_intervals_across_snapshot_versions(
    state_reader: StateReader,
    prod_restatements: t.Dict[str, Interval],
    disable_restatement_models: t.Set[str],
    loaded_snapshots: t.Dict[SnapshotId, Snapshot],
    current_ts: t.Optional[int] = None,
) -> t.Dict[SnapshotId, SnapshotIntervalClearRequest]:
    """
    Given a map of snapshot names + intervals to restate in prod:
     - Look up matching snapshots (match based on name - regardless of version, to get all versions)
     - For each match, also match downstream snapshots in each dev environment while filtering out
       models that have restatement disabled
     - Return every affected snapshot together with the interval that needs to be cleared for it

    The goal here is to produce the intervals to invalidate across all dev snapshots so that a subsequent
    plan or cadence run in those environments causes the intervals to be repopulated.

    Args:
        state_reader: Used to list environments and to fetch snapshots from state.
        prod_restatements: Snapshot name -> interval being restated in prod.
        disable_restatement_models: Names of models that must never be included in the result.
        loaded_snapshots: Snapshots that are already loaded, keyed by id. Their name/versions are treated
            as the prod versions, which are never cleared here. NOTE: this dict is updated in place with
            any unpromoted snapshots that get loaded from state along the way.
        current_ts: Timestamp passed to the state lookup that excludes expired snapshots.
            Defaults to the current time when not provided.

    Returns:
        A dict of snapshot id -> clear request. Empty if there is nothing to restate or nothing is affected.
    """
    if not prod_restatements:
        return {}

    # Although :loaded_snapshots is sourced from RestatementStage.all_snapshots, since the only time we ever need
    # to clear intervals across all environments is for prod, the :loaded_snapshots here are always from prod
    prod_name_versions: t.Set[SnapshotNameVersion] = {
        s.name_version for s in loaded_snapshots.values()
    }

    snapshot_intervals_to_clear: t.Dict[SnapshotId, SnapshotIntervalClearRequest] = {}

    for env_summary in state_reader.get_environments_summary():
        # Fetch the full environment object one at a time to avoid loading all environments into memory at once
        env = state_reader.get_environment(env_summary.name)
        if not env:
            logger.warning("Environment %s not found", env_summary.name)
            continue

        snapshots_by_name = {s.name: s.table_info for s in env.snapshots}

        # We don't just restate matching snapshots, we also have to restate anything downstream of them
        # so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
        env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})

        for restate_snapshot_name, interval in prod_restatements.items():
            if restate_snapshot_name not in snapshots_by_name:
                # snapshot is not promoted in this environment
                continue

            affected_snapshot_names = [
                x
                for x in ([restate_snapshot_name] + env_dag.downstream(restate_snapshot_name))
                if x not in disable_restatement_models
            ]

            for affected_snapshot_name in affected_snapshot_names:
                affected_snapshot = snapshots_by_name[affected_snapshot_name]

                # Don't clear intervals for a dev snapshot if it shares the same physical version with prod.
                # Otherwise, prod will be affected by what should be a dev operation
                if affected_snapshot.name_version in prod_name_versions:
                    continue

                clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id)
                if clear_request is None:
                    clear_request = SnapshotIntervalClearRequest(
                        table_info=affected_snapshot, interval=interval
                    )
                    snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request
                else:
                    # The same snapshot can be reached through more than one restated model (directly and/or
                    # as a downstream dependency). Keeping only the first interval seen would make the result
                    # depend on iteration order and could leave part of a restated range uncleared, so
                    # conservatively widen the request to cover every interval that applies
                    clear_request.interval = _widest_interval(clear_request.interval, interval)

                clear_request.environment_names.add(env.name)

    if not snapshot_intervals_to_clear:
        # Nothing is affected in any environment. The lookup of unpromoted snapshots below is keyed on the
        # names collected above, so there is nothing further to find and no need to query state
        return snapshot_intervals_to_clear

    # snapshot_intervals_to_clear now contains the entire hierarchy of affected snapshots based
    # on building the DAG for each environment and including downstream snapshots
    # but, what if there are affected snapshots that aren't part of any environment?
    unique_snapshot_names = {snapshot_id.name for snapshot_id in snapshot_intervals_to_clear}

    # Compare against None explicitly so that an explicitly provided timestamp of 0 is honoured
    if current_ts is None:
        current_ts = now_timestamp()

    all_matching_non_prod_snapshots = {
        s.snapshot_id: s
        for s in state_reader.get_snapshots_by_names(
            snapshot_names=unique_snapshot_names, current_ts=current_ts, exclude_expired=True
        )
        # Don't clear intervals for a snapshot if it shares the same physical version with prod.
        # Otherwise, prod will be affected by what should be a dev operation
        if s.name_version not in prod_name_versions
    }

    # identify the ones that we haven't picked up yet, which are the ones that don't exist in any environment
    if remaining_snapshot_ids := set(all_matching_non_prod_snapshots).difference(
        snapshot_intervals_to_clear
    ):
        # these snapshot ids exist in isolation and may be related to a downstream dependency of the
        # :prod_restatements, rather than directly related, so we can't simply look up the interval to clear
        # based on :prod_restatements. To figure out the interval that should be cleared, we can match to the
        # existing requests based on name and conservatively take the widest interval that shows up
        snapshot_name_to_widest_interval: t.Dict[str, Interval] = {}
        for s_id, clear_request in snapshot_intervals_to_clear.items():
            widest_so_far = snapshot_name_to_widest_interval.get(s_id.name, clear_request.interval)
            snapshot_name_to_widest_interval[s_id.name] = _widest_interval(
                widest_so_far, clear_request.interval
            )

        # we need to fetch full Snapshots to get access to the SnapshotTableInfo objects
        # required by StateSync.remove_intervals()
        # but at this point we have minimized the list by excluding the ones that are already present in prod
        # and also excluding the ones we have already matched earlier while traversing the environment DAGs
        remaining_snapshots = state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids)
        for remaining_snapshot_id, remaining_snapshot in remaining_snapshots.items():
            snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest(
                table_info=remaining_snapshot.table_info,
                interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name],
            )

        # remember these so they are not fetched from state a second time below
        loaded_snapshots.update(remaining_snapshots)

    # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
    # include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
    # so we only do it if necessary
    full_history_restatement_snapshot_ids = [
        # FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
        # however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
        # is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
        # So for now, these are not considered
        s_id
        for s_id, clear_request in snapshot_intervals_to_clear.items()
        if clear_request.table_info.full_history_restatement_only
    ]
    if full_history_restatement_snapshot_ids:
        # only load full snapshot records that we haven't already loaded
        additional_snapshots = state_reader.get_snapshots(
            [
                s_id.snapshot_id
                for s_id in full_history_restatement_snapshot_ids
                if s_id.snapshot_id not in loaded_snapshots
            ]
        )

        all_snapshots = loaded_snapshots | additional_snapshots

        for full_snapshot_id in full_history_restatement_snapshot_ids:
            full_snapshot = all_snapshots[full_snapshot_id]
            intervals_to_clear = snapshot_intervals_to_clear[full_snapshot_id]

            original_start, original_end = intervals_to_clear.interval

            # get_removal_interval() widens intervals if necessary
            intervals_to_clear.interval = full_snapshot.get_removal_interval(
                start=original_start, end=original_end
            )

    return snapshot_intervals_to_clear

sqlmesh/core/plan/evaluator.py

Lines changed: 13 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sqlmesh.core.console import Console, get_console
2323
from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements
2424
from sqlmesh.core.macros import RuntimeStage
25-
from sqlmesh.core.snapshot.definition import Interval, to_view_mapping
25+
from sqlmesh.core.snapshot.definition import to_view_mapping
2626
from sqlmesh.core.plan import stages
2727
from sqlmesh.core.plan.definition import EvaluatablePlan
2828
from sqlmesh.core.scheduler import Scheduler
@@ -33,17 +33,15 @@
3333
SnapshotIntervals,
3434
SnapshotId,
3535
SnapshotInfoLike,
36-
SnapshotTableInfo,
3736
SnapshotCreationFailedError,
38-
SnapshotNameVersion,
3937
)
4038
from sqlmesh.utils import to_snake_case
4139
from sqlmesh.core.state_sync import StateSync
40+
from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions
4241
from sqlmesh.utils import CorrelationId
4342
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4443
from sqlmesh.utils.errors import PlanError, SQLMeshError
45-
from sqlmesh.utils.dag import DAG
46-
from sqlmesh.utils.date import now
44+
from sqlmesh.utils.date import now, to_timestamp
4745

4846
logger = logging.getLogger(__name__)
4947

@@ -295,11 +293,16 @@ def visit_restatement_stage(
295293
#
296294
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
297295
snapshot_intervals_to_restate.update(
298-
self._restatement_intervals_across_all_environments(
299-
prod_restatements=plan.restatements,
300-
disable_restatement_models=plan.disabled_restatement_models,
301-
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
302-
)
296+
{
297+
(s.table_info, s.interval)
298+
for s in identify_restatement_intervals_across_snapshot_versions(
299+
state_reader=self.state_sync,
300+
prod_restatements=plan.restatements,
301+
disable_restatement_models=plan.disabled_restatement_models,
302+
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
303+
current_ts=to_timestamp(plan.execution_time or now()),
304+
).values()
305+
}
303306
)
304307

305308
self.state_sync.remove_intervals(
@@ -419,97 +422,6 @@ def _demote_snapshots(
419422
on_complete=on_complete,
420423
)
421424

422-
def _restatement_intervals_across_all_environments(
423-
self,
424-
prod_restatements: t.Dict[str, Interval],
425-
disable_restatement_models: t.Set[str],
426-
loaded_snapshots: t.Dict[SnapshotId, Snapshot],
427-
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
428-
"""
429-
Given a map of snapshot names + intervals to restate in prod:
430-
- Look up matching snapshots across all environments (match based on name - regardless of version)
431-
- For each match, also match downstream snapshots while filtering out models that have restatement disabled
432-
- Return all matches mapped to the intervals of the prod snapshot being restated
433-
434-
The goal here is to produce a list of intervals to invalidate across all environments so that a cadence
435-
run in those environments causes the intervals to be repopulated
436-
"""
437-
if not prod_restatements:
438-
return set()
439-
440-
prod_name_versions: t.Set[SnapshotNameVersion] = {
441-
s.name_version for s in loaded_snapshots.values()
442-
}
443-
444-
snapshots_to_restate: t.Dict[SnapshotId, t.Tuple[SnapshotTableInfo, Interval]] = {}
445-
446-
for env_summary in self.state_sync.get_environments_summary():
447-
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
448-
env = self.state_sync.get_environment(env_summary.name)
449-
if not env:
450-
logger.warning("Environment %s not found", env_summary.name)
451-
continue
452-
453-
keyed_snapshots = {s.name: s.table_info for s in env.snapshots}
454-
455-
# We dont just restate matching snapshots, we also have to restate anything downstream of them
456-
# so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev
457-
env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots})
458-
459-
for restatement, intervals in prod_restatements.items():
460-
if restatement not in keyed_snapshots:
461-
continue
462-
affected_snapshot_names = [
463-
x
464-
for x in ([restatement] + env_dag.downstream(restatement))
465-
if x not in disable_restatement_models
466-
]
467-
snapshots_to_restate.update(
468-
{
469-
keyed_snapshots[a].snapshot_id: (keyed_snapshots[a], intervals)
470-
for a in affected_snapshot_names
471-
# Don't restate a snapshot if it shares the version with a snapshot in prod
472-
if keyed_snapshots[a].name_version not in prod_name_versions
473-
}
474-
)
475-
476-
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
477-
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
478-
# so we only do it if necessary
479-
full_history_restatement_snapshot_ids = [
480-
# FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self
481-
# however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state,
482-
# is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present)
483-
# So for now, these are not considered
484-
s_id
485-
for s_id, s in snapshots_to_restate.items()
486-
if s[0].full_history_restatement_only
487-
]
488-
if full_history_restatement_snapshot_ids:
489-
# only load full snapshot records that we havent already loaded
490-
additional_snapshots = self.state_sync.get_snapshots(
491-
[
492-
s.snapshot_id
493-
for s in full_history_restatement_snapshot_ids
494-
if s.snapshot_id not in loaded_snapshots
495-
]
496-
)
497-
498-
all_snapshots = loaded_snapshots | additional_snapshots
499-
500-
for full_snapshot_id in full_history_restatement_snapshot_ids:
501-
full_snapshot = all_snapshots[full_snapshot_id]
502-
_, original_intervals = snapshots_to_restate[full_snapshot_id]
503-
original_start, original_end = original_intervals
504-
505-
# get_removal_interval() widens intervals if necessary
506-
new_intervals = full_snapshot.get_removal_interval(
507-
start=original_start, end=original_end
508-
)
509-
snapshots_to_restate[full_snapshot_id] = (full_snapshot.table_info, new_intervals)
510-
511-
return set(snapshots_to_restate.values())
512-
513425
def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None:
514426
snapshots_intervals: t.List[SnapshotIntervals] = []
515427
for snapshot in snapshots:

0 commit comments

Comments
 (0)