Skip to content

Commit 7640ca0

Browse files
committed
PR feedback
1 parent 650dc13 commit 7640ca0

File tree

2 files changed

+68
-41
lines changed

2 files changed

+68
-41
lines changed

sqlmesh/core/plan/explainer.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ def evaluate(
5959

6060
# add extra metadata that's only needed at this point for better --explain output
6161
plan_stages = [
62-
ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan)
62+
ExplainableRestatementStage.from_restatement_stage(
63+
stage,
64+
self.state_reader,
65+
plan,
66+
fetch_full_snapshots=explainer_console.verbosity == Verbosity.VERY_VERBOSE,
67+
)
6368
if isinstance(stage, stages.RestatementStage)
6469
else stage
6570
for stage in plan_stages
@@ -73,6 +78,11 @@ class ExplainerConsole(abc.ABC):
7378
def explain(self, stages: t.List[stages.PlanStage]) -> None:
7479
pass
7580

81+
@property
82+
@abc.abstractmethod
83+
def verbosity(self) -> Verbosity:
84+
pass
85+
7686

7787
@dataclass
7888
class ExplainableRestatementStage(stages.RestatementStage):
@@ -82,7 +92,7 @@ class ExplainableRestatementStage(stages.RestatementStage):
8292
"""
8393

8494
snapshot_intervals_to_clear: t.Dict[
85-
str, t.List[t.Tuple[Snapshot, SnapshotIntervalClearRequest]]
95+
str, t.List[t.Tuple[t.Optional[Snapshot], SnapshotIntervalClearRequest]]
8696
]
8797
"""Which snapshots from other environments would have intervals cleared as part of restatement, grouped by name."""
8898

@@ -92,6 +102,7 @@ def from_restatement_stage(
92102
stage: stages.RestatementStage,
93103
state_reader: StateReader,
94104
plan: EvaluatablePlan,
105+
fetch_full_snapshots: bool = False,
95106
) -> ExplainableRestatementStage:
96107
loaded_snapshots = {s.snapshot_id: s for s in stage.all_snapshots.values()}
97108

@@ -102,17 +113,20 @@ def from_restatement_stage(
102113
loaded_snapshots=loaded_snapshots,
103114
)
104115

105-
# extend loaded_snapshots with the remaining full Snapshot objects from all_restatement_intervals
106-
# so that we can generate physical table names for them while explaining what's going on
107-
remaining_snapshot_ids_to_load = set(all_restatement_intervals).difference(loaded_snapshots)
108-
loaded_snapshots.update(
109-
state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids_to_load)
110-
)
116+
if fetch_full_snapshots:
117+
# extend loaded_snapshots with the remaining full Snapshot objects from all_restatement_intervals
118+
# so that we can generate physical table names for them while explaining what's going on
119+
remaining_snapshot_ids_to_load = set(all_restatement_intervals).difference(
120+
loaded_snapshots
121+
)
122+
loaded_snapshots.update(
123+
state_reader.get_snapshots(snapshot_ids=remaining_snapshot_ids_to_load)
124+
)
111125

112126
snapshot_intervals_to_clear = defaultdict(list)
113127
for snapshot_id, clear_request in all_restatement_intervals.items():
114128
snapshot_intervals_to_clear[clear_request.snapshot.name].append(
115-
(loaded_snapshots[snapshot_id], clear_request)
129+
(loaded_snapshots.get(snapshot_id), clear_request)
116130
)
117131

118132
return cls(
@@ -136,9 +150,13 @@ def __init__(
136150
self.environment_naming_info = environment_naming_info
137151
self.dialect = dialect
138152
self.default_catalog = default_catalog
139-
self.verbosity = verbosity
153+
self._verbosity = verbosity
140154
self.console: RichConsole = console or srich.console
141155

156+
@property
157+
def verbosity(self) -> Verbosity:
158+
return self._verbosity
159+
142160
def explain(self, stages: t.List[stages.PlanStage]) -> None:
143161
tree = Tree("[bold]Explained plan[/bold]")
144162
for stage in stages:
@@ -228,41 +246,49 @@ def visit_restatement_stage(
228246
snapshot_intervals := stage.snapshot_intervals_to_clear
229247
):
230248
for name, requests in snapshot_intervals.items():
249+
if not requests:
250+
# ensure that there is at least one SnapshotIntervalClearRequest in the list
251+
continue
252+
231253
display_name = model_display_name(
232254
name, self.environment_naming_info, self.default_catalog, self.dialect
233255
)
234-
235-
# group by environment for the console output
236-
by_environment: t.Dict[t.Optional[str], t.List[Snapshot]] = defaultdict(list)
237-
238-
interval_start = None
239-
interval_end = None
240-
241-
for snapshot, clear_request in requests:
242-
# used for the top level tree node
243-
interval_start, interval_end = clear_request.interval
244-
245-
if clear_request.sorted_environment_names:
246-
# snapshot is promoted in these environments
247-
for env in clear_request.sorted_environment_names:
248-
by_environment[env].append(snapshot)
249-
else:
250-
# snapshot is not currently promoted in any environment
251-
by_environment[None].append(snapshot)
252-
253-
if not interval_start or not interval_end:
254-
continue
255-
256+
_, clear_request = requests[0]
257+
interval_start, interval_end = clear_request.interval
256258
node = tree.add(f"{display_name} [{to_ts(interval_start)} - {to_ts(interval_end)}]")
257259

258-
for env_name, snapshots_to_clear in by_environment.items():
259-
env_name = env_name or "(no env)"
260-
for snapshot in snapshots_to_clear:
261-
# note: we dont need a DeployabilityIndex and can just hardcode is_deployable=True.
262-
# The reason is that non-deployable data can never be restated so we only need to
263-
# bother clearing intervals for the deployable version of the table
264-
physical_table_name = snapshot.table_name(True)
265-
node.add(f"{env_name} -> {physical_table_name}")
260+
if not self.verbosity.is_very_verbose:
261+
# In normal mode we just indicate which environments are affected at a high level
262+
all_environment_names = sorted(
263+
set(env_name for (_, cr) in requests for env_name in cr.environment_names)
264+
)
265+
node.add("in environments: " + ", ".join(all_environment_names))
266+
else:
267+
# In "very verbose" mode, we print all the affected physical tables
268+
269+
# group by environment for the console output
270+
by_environment: t.Dict[t.Optional[str], t.List[Snapshot]] = defaultdict(list)
271+
272+
for snapshot, clear_request in requests:
273+
if not snapshot:
274+
# Snapshots are None (not loaded) unless from_restatement_stage() was called with fetch_full_snapshots=True
275+
continue
276+
if clear_request.sorted_environment_names:
277+
# snapshot is promoted in these environments
278+
for env in clear_request.sorted_environment_names:
279+
by_environment[env].append(snapshot)
280+
else:
281+
# snapshot is not currently promoted in any environment
282+
by_environment[None].append(snapshot)
283+
284+
for env_name, snapshots_to_clear in by_environment.items():
285+
env_name = env_name or "(no env)"
286+
for snapshot in snapshots_to_clear:
287+
# note: we don't need a DeployabilityIndex and can just hardcode is_deployable=True.
288+
# The reason is that non-deployable data can never be restated so we only need to
289+
# bother clearing intervals for the deployable version of the table
290+
physical_table_name = snapshot.table_name(True)
291+
node.add(f"{env_name} -> {physical_table_name}")
266292

267293
return tree
268294

tests/core/test_plan_stages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
765765
restatement_stage = stages[2]
766766
assert isinstance(restatement_stage, RestatementStage)
767767
restatement_stage = ExplainableRestatementStage.from_restatement_stage(
768-
restatement_stage, state_reader, plan
768+
restatement_stage, state_reader, plan, fetch_full_snapshots=True
769769
)
770770

771771
# note: we only clear the intervals from state for "a" in dev, we leave prod alone
@@ -774,6 +774,7 @@ def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]):
774774
snapshot_name, clear_requests = list(restatement_stage.snapshot_intervals_to_clear.items())[0]
775775
assert len(clear_requests) == 1
776776
full_snapshot, clear_request = clear_requests[0]
777+
assert full_snapshot is not None
777778
assert isinstance(clear_request, SnapshotIntervalClearRequest)
778779
assert snapshot_name == '"a"'
779780
assert full_snapshot.name == snapshot_name

0 commit comments

Comments
 (0)