Skip to content

Commit 9b26320

Browse files
Fix: Don't use SCD type 2 restatement logic in regular runs (#4976)
1 parent 39caa0f commit 9b26320

File tree

8 files changed

+210
-7
lines changed

8 files changed

+210
-7
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1462,6 +1462,7 @@ def scd_type_2_by_time(
14621462
table_description: t.Optional[str] = None,
14631463
column_descriptions: t.Optional[t.Dict[str, str]] = None,
14641464
truncate: bool = False,
1465+
is_restatement: bool = False,
14651466
**kwargs: t.Any,
14661467
) -> None:
14671468
self._scd_type_2(
@@ -1478,6 +1479,7 @@ def scd_type_2_by_time(
14781479
table_description=table_description,
14791480
column_descriptions=column_descriptions,
14801481
truncate=truncate,
1482+
is_restatement=is_restatement,
14811483
**kwargs,
14821484
)
14831485

@@ -1496,6 +1498,7 @@ def scd_type_2_by_column(
14961498
table_description: t.Optional[str] = None,
14971499
column_descriptions: t.Optional[t.Dict[str, str]] = None,
14981500
truncate: bool = False,
1501+
is_restatement: bool = False,
14991502
**kwargs: t.Any,
15001503
) -> None:
15011504
self._scd_type_2(
@@ -1512,6 +1515,7 @@ def scd_type_2_by_column(
15121515
table_description=table_description,
15131516
column_descriptions=column_descriptions,
15141517
truncate=truncate,
1518+
is_restatement=is_restatement,
15151519
**kwargs,
15161520
)
15171521

@@ -1533,6 +1537,7 @@ def _scd_type_2(
15331537
table_description: t.Optional[str] = None,
15341538
column_descriptions: t.Optional[t.Dict[str, str]] = None,
15351539
truncate: bool = False,
1540+
is_restatement: bool = False,
15361541
**kwargs: t.Any,
15371542
) -> None:
15381543
def remove_managed_columns(
@@ -1718,13 +1723,15 @@ def remove_managed_columns(
17181723
target_table
17191724
)
17201725

1721-
cleanup_ts = None
17221726
if truncate:
17231727
existing_rows_query = existing_rows_query.limit(0)
1724-
else:
1725-
# If truncate is false it is not the first insert
1726-
# Determine the cleanup timestamp for restatement or a regular incremental run
1727-
cleanup_ts = to_time_column(start, time_data_type, self.dialect, nullable=True)
1728+
1729+
# Only set cleanup_ts if is_restatement is True and truncate is False (this to enable full restatement)
1730+
cleanup_ts = (
1731+
to_time_column(start, time_data_type, self.dialect, nullable=True)
1732+
if is_restatement and not truncate
1733+
else None
1734+
)
17281735

17291736
with source_queries[0] as source_query:
17301737
prefixed_columns_to_types = []
@@ -1763,7 +1770,7 @@ def remove_managed_columns(
17631770
.with_(
17641771
"static",
17651772
existing_rows_query.where(valid_to_col.is_(exp.Null()).not_())
1766-
if truncate
1773+
if cleanup_ts is None
17671774
else existing_rows_query.where(
17681775
exp.and_(
17691776
valid_to_col.is_(exp.Null().not_()),
@@ -1775,7 +1782,7 @@ def remove_managed_columns(
17751782
.with_(
17761783
"latest",
17771784
existing_rows_query.where(valid_to_col.is_(exp.Null()))
1778-
if truncate
1785+
if cleanup_ts is None
17791786
else exp.select(
17801787
*(
17811788
to_time_column(

sqlmesh/core/engine_adapter/trino.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def _scd_type_2(
267267
table_description: t.Optional[str] = None,
268268
column_descriptions: t.Optional[t.Dict[str, str]] = None,
269269
truncate: bool = False,
270+
is_restatement: bool = False,
270271
**kwargs: t.Any,
271272
) -> None:
272273
if columns_to_types and self.current_catalog_type == "delta_lake":
@@ -289,6 +290,7 @@ def _scd_type_2(
289290
table_description,
290291
column_descriptions,
291292
truncate,
293+
is_restatement,
292294
**kwargs,
293295
)
294296

sqlmesh/core/plan/evaluator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234234
return
235235

236236
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
237+
# Convert model name restatements to snapshot ID restatements
238+
restatements_by_snapshot_id = {
239+
stage.all_snapshots[name].snapshot_id: interval
240+
for name, interval in plan.restatements.items()
241+
}
237242
errors, _ = scheduler.run_merged_intervals(
238243
merged_intervals=stage.snapshot_to_intervals,
239244
deployability_index=stage.deployability_index,
@@ -242,6 +247,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
242247
circuit_breaker=self._circuit_breaker,
243248
start=plan.start,
244249
end=plan.end,
250+
restatements=restatements_by_snapshot_id,
245251
)
246252
if errors:
247253
raise PlanError("Plan application failed.")

sqlmesh/core/scheduler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def evaluate(
161161
deployability_index: DeployabilityIndex,
162162
batch_index: int,
163163
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
164+
is_restatement: bool = False,
164165
**kwargs: t.Any,
165166
) -> t.List[AuditResult]:
166167
"""Evaluate a snapshot and add the processed interval to the state sync.
@@ -192,6 +193,7 @@ def evaluate(
192193
snapshots=snapshots,
193194
deployability_index=deployability_index,
194195
batch_index=batch_index,
196+
is_restatement=is_restatement,
195197
**kwargs,
196198
)
197199
audit_results = self._audit_snapshot(
@@ -371,6 +373,7 @@ def run_merged_intervals(
371373
end: t.Optional[TimeLike] = None,
372374
run_environment_statements: bool = False,
373375
audit_only: bool = False,
376+
restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
374377
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
375378
"""Runs precomputed batches of missing intervals.
376379
@@ -447,6 +450,10 @@ def evaluate_node(node: SchedulingUnit) -> None:
447450
execution_time=execution_time,
448451
)
449452
else:
453+
# Determine if this snapshot and interval is a restatement (for SCD type 2)
454+
is_restatement = (
455+
restatements is not None and snapshot.snapshot_id in restatements
456+
)
450457
audit_results = self.evaluate(
451458
snapshot=snapshot,
452459
environment_naming_info=environment_naming_info,
@@ -455,6 +462,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
455462
execution_time=execution_time,
456463
deployability_index=deployability_index,
457464
batch_index=batch_idx,
465+
is_restatement=is_restatement,
458466
)
459467

460468
evaluation_duration_ms = now_timestamp() - execution_start_ts
@@ -663,6 +671,7 @@ def _run_or_audit(
663671
end=end,
664672
run_environment_statements=run_environment_statements,
665673
audit_only=audit_only,
674+
restatements=remove_intervals,
666675
)
667676

668677
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

sqlmesh/core/snapshot/evaluator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def evaluate(
140140
snapshots: t.Dict[str, Snapshot],
141141
deployability_index: t.Optional[DeployabilityIndex] = None,
142142
batch_index: int = 0,
143+
is_restatement: bool = False,
143144
**kwargs: t.Any,
144145
) -> t.Optional[str]:
145146
"""Renders the snapshot's model, executes it and stores the result in the snapshot's physical table.
@@ -165,6 +166,7 @@ def evaluate(
165166
snapshots,
166167
deployability_index=deployability_index,
167168
batch_index=batch_index,
169+
is_restatement=is_restatement,
168170
**kwargs,
169171
)
170172
if result is None or isinstance(result, str):
@@ -622,6 +624,7 @@ def _evaluate_snapshot(
622624
limit: t.Optional[int] = None,
623625
deployability_index: t.Optional[DeployabilityIndex] = None,
624626
batch_index: int = 0,
627+
is_restatement: bool = False,
625628
**kwargs: t.Any,
626629
) -> DF | str | None:
627630
"""Renders the snapshot's model and executes it. The return value depends on whether the limit was specified.
@@ -694,6 +697,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
694697
end=end,
695698
execution_time=execution_time,
696699
physical_properties=rendered_physical_properties,
700+
is_restatement=is_restatement,
697701
)
698702
else:
699703
logger.info(
@@ -715,6 +719,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
715719
end=end,
716720
execution_time=execution_time,
717721
physical_properties=rendered_physical_properties,
722+
is_restatement=is_restatement,
718723
)
719724

720725
with (
@@ -1833,6 +1838,7 @@ def insert(
18331838
column_descriptions=model.column_descriptions,
18341839
truncate=is_first_insert,
18351840
start=kwargs["start"],
1841+
is_restatement=kwargs.get("is_restatement", False),
18361842
)
18371843
elif isinstance(model.kind, SCDType2ByColumnKind):
18381844
self.adapter.scd_type_2_by_column(
@@ -1851,6 +1857,7 @@ def insert(
18511857
column_descriptions=model.column_descriptions,
18521858
truncate=is_first_insert,
18531859
start=kwargs["start"],
1860+
is_restatement=kwargs.get("is_restatement", False),
18541861
)
18551862
else:
18561863
raise SQLMeshError(

tests/core/engine_adapter/test_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
12231223
},
12241224
execution_time=datetime(2020, 1, 1, 0, 0, 0),
12251225
start=datetime(2020, 1, 1, 0, 0, 0),
1226+
is_restatement=True,
12261227
)
12271228

12281229
assert (
@@ -1422,6 +1423,7 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte
14221423
},
14231424
execution_time=datetime(2020, 1, 1, 0, 0, 0),
14241425
start=datetime(2020, 1, 1, 0, 0, 0),
1426+
is_restatement=True,
14251427
)
14261428

14271429
assert (
@@ -1610,6 +1612,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable):
16101612
},
16111613
execution_time=datetime(2020, 1, 1, 0, 0, 0),
16121614
start=datetime(2020, 1, 1, 0, 0, 0),
1615+
is_restatement=True,
16131616
)
16141617

16151618
assert (
@@ -1799,6 +1802,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable):
17991802
execution_time=datetime(2020, 1, 1, 0, 0, 0),
18001803
start=datetime(2020, 1, 1, 0, 0, 0),
18011804
extra_col_ignore="testing",
1805+
is_restatement=True,
18021806
)
18031807

18041808
assert (
@@ -1990,6 +1994,7 @@ def test_scd_type_2_by_column_composite_key(make_mocked_engine_adapter: t.Callab
19901994
},
19911995
execution_time=datetime(2020, 1, 1, 0, 0, 0),
19921996
start=datetime(2020, 1, 1, 0, 0, 0),
1997+
is_restatement=True,
19931998
)
19941999
assert (
19952000
parse_one(adapter.cursor.execute.call_args[0][0]).sql()
@@ -2352,6 +2357,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable)
23522357
},
23532358
execution_time=datetime(2020, 1, 1, 0, 0, 0),
23542359
start=datetime(2020, 1, 1, 0, 0, 0),
2360+
is_restatement=True,
23552361
)
23562362

23572363
assert (
@@ -2527,6 +2533,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap
25272533
},
25282534
execution_time=datetime(2020, 1, 1, 0, 0, 0),
25292535
start=datetime(2020, 1, 1, 0, 0, 0),
2536+
is_restatement=True,
25302537
)
25312538

25322539
assert (

0 commit comments

Comments
 (0)