Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions sqlmesh/core/engine_adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,7 @@ def scd_type_2_by_time(
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
truncate: bool = False,
is_restatement: bool = False,
**kwargs: t.Any,
) -> None:
self._scd_type_2(
Expand All @@ -1478,6 +1479,7 @@ def scd_type_2_by_time(
table_description=table_description,
column_descriptions=column_descriptions,
truncate=truncate,
is_restatement=is_restatement,
**kwargs,
)

Expand All @@ -1496,6 +1498,7 @@ def scd_type_2_by_column(
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
truncate: bool = False,
is_restatement: bool = False,
**kwargs: t.Any,
) -> None:
self._scd_type_2(
Expand All @@ -1512,6 +1515,7 @@ def scd_type_2_by_column(
table_description=table_description,
column_descriptions=column_descriptions,
truncate=truncate,
is_restatement=is_restatement,
**kwargs,
)

Expand All @@ -1533,6 +1537,7 @@ def _scd_type_2(
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
truncate: bool = False,
is_restatement: bool = False,
**kwargs: t.Any,
) -> None:
def remove_managed_columns(
Expand Down Expand Up @@ -1718,13 +1723,15 @@ def remove_managed_columns(
target_table
)

cleanup_ts = None
if truncate:
existing_rows_query = existing_rows_query.limit(0)
else:
# If truncate is false it is not the first insert
# Determine the cleanup timestamp for restatement or a regular incremental run
cleanup_ts = to_time_column(start, time_data_type, self.dialect, nullable=True)

# Only set cleanup_ts when is_restatement is True and truncate is False (this is to enable full restatement)
cleanup_ts = (
to_time_column(start, time_data_type, self.dialect, nullable=True)
if is_restatement and not truncate
else None
)

with source_queries[0] as source_query:
prefixed_columns_to_types = []
Expand Down Expand Up @@ -1763,7 +1770,7 @@ def remove_managed_columns(
.with_(
"static",
existing_rows_query.where(valid_to_col.is_(exp.Null()).not_())
if truncate
if cleanup_ts is None
else existing_rows_query.where(
exp.and_(
valid_to_col.is_(exp.Null().not_()),
Expand All @@ -1775,7 +1782,7 @@ def remove_managed_columns(
.with_(
"latest",
existing_rows_query.where(valid_to_col.is_(exp.Null()))
if truncate
if cleanup_ts is None
else exp.select(
*(
to_time_column(
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/engine_adapter/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def _scd_type_2(
table_description: t.Optional[str] = None,
column_descriptions: t.Optional[t.Dict[str, str]] = None,
truncate: bool = False,
is_restatement: bool = False,
**kwargs: t.Any,
) -> None:
if columns_to_types and self.current_catalog_type == "delta_lake":
Expand All @@ -289,6 +290,7 @@ def _scd_type_2(
table_description,
column_descriptions,
truncate,
is_restatement,
**kwargs,
)

Expand Down
6 changes: 6 additions & 0 deletions sqlmesh/core/plan/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
return

scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
# Convert model name restatements to snapshot ID restatements
restatements_by_snapshot_id = {
stage.all_snapshots[name].snapshot_id: interval
for name, interval in plan.restatements.items()
}
errors, _ = scheduler.run_merged_intervals(
merged_intervals=stage.snapshot_to_intervals,
deployability_index=stage.deployability_index,
Expand All @@ -242,6 +247,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
circuit_breaker=self._circuit_breaker,
start=plan.start,
end=plan.end,
restatements=restatements_by_snapshot_id,
)
if errors:
raise PlanError("Plan application failed.")
Expand Down
9 changes: 9 additions & 0 deletions sqlmesh/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def evaluate(
deployability_index: DeployabilityIndex,
batch_index: int,
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
is_restatement: bool = False,
**kwargs: t.Any,
) -> t.List[AuditResult]:
"""Evaluate a snapshot and add the processed interval to the state sync.
Expand Down Expand Up @@ -192,6 +193,7 @@ def evaluate(
snapshots=snapshots,
deployability_index=deployability_index,
batch_index=batch_index,
is_restatement=is_restatement,
**kwargs,
)
audit_results = self._audit_snapshot(
Expand Down Expand Up @@ -371,6 +373,7 @@ def run_merged_intervals(
end: t.Optional[TimeLike] = None,
run_environment_statements: bool = False,
audit_only: bool = False,
restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
"""Runs precomputed batches of missing intervals.

Expand Down Expand Up @@ -447,6 +450,10 @@ def evaluate_node(node: SchedulingUnit) -> None:
execution_time=execution_time,
)
else:
# Determine if this snapshot is being restated (for SCD type 2); only the snapshot ID is checked, not the interval
is_restatement = (
restatements is not None and snapshot.snapshot_id in restatements
)
audit_results = self.evaluate(
snapshot=snapshot,
environment_naming_info=environment_naming_info,
Expand All @@ -455,6 +462,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
execution_time=execution_time,
deployability_index=deployability_index,
batch_index=batch_idx,
is_restatement=is_restatement,
)

evaluation_duration_ms = now_timestamp() - execution_start_ts
Expand Down Expand Up @@ -663,6 +671,7 @@ def _run_or_audit(
end=end,
run_environment_statements=run_environment_statements,
audit_only=audit_only,
restatements=remove_intervals,
)

return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS
Expand Down
7 changes: 7 additions & 0 deletions sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def evaluate(
snapshots: t.Dict[str, Snapshot],
deployability_index: t.Optional[DeployabilityIndex] = None,
batch_index: int = 0,
is_restatement: bool = False,
**kwargs: t.Any,
) -> t.Optional[str]:
"""Renders the snapshot's model, executes it and stores the result in the snapshot's physical table.
Expand All @@ -165,6 +166,7 @@ def evaluate(
snapshots,
deployability_index=deployability_index,
batch_index=batch_index,
is_restatement=is_restatement,
**kwargs,
)
if result is None or isinstance(result, str):
Expand Down Expand Up @@ -622,6 +624,7 @@ def _evaluate_snapshot(
limit: t.Optional[int] = None,
deployability_index: t.Optional[DeployabilityIndex] = None,
batch_index: int = 0,
is_restatement: bool = False,
**kwargs: t.Any,
) -> DF | str | None:
"""Renders the snapshot's model and executes it. The return value depends on whether the limit was specified.
Expand Down Expand Up @@ -694,6 +697,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
end=end,
execution_time=execution_time,
physical_properties=rendered_physical_properties,
is_restatement=is_restatement,
)
else:
logger.info(
Expand All @@ -715,6 +719,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None:
end=end,
execution_time=execution_time,
physical_properties=rendered_physical_properties,
is_restatement=is_restatement,
)

with (
Expand Down Expand Up @@ -1833,6 +1838,7 @@ def insert(
column_descriptions=model.column_descriptions,
truncate=is_first_insert,
start=kwargs["start"],
is_restatement=kwargs.get("is_restatement", False),
)
elif isinstance(model.kind, SCDType2ByColumnKind):
self.adapter.scd_type_2_by_column(
Expand All @@ -1851,6 +1857,7 @@ def insert(
column_descriptions=model.column_descriptions,
truncate=is_first_insert,
start=kwargs["start"],
is_restatement=kwargs.get("is_restatement", False),
)
else:
raise SQLMeshError(
Expand Down
7 changes: 7 additions & 0 deletions tests/core/engine_adapter/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
},
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
is_restatement=True,
)

assert (
Expand Down Expand Up @@ -1422,6 +1423,7 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte
},
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
is_restatement=True,
)

assert (
Expand Down Expand Up @@ -1610,6 +1612,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable):
},
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
is_restatement=True,
)

assert (
Expand Down Expand Up @@ -1799,6 +1802,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable):
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
extra_col_ignore="testing",
is_restatement=True,
)

assert (
Expand Down Expand Up @@ -1990,6 +1994,7 @@ def test_scd_type_2_by_column_composite_key(make_mocked_engine_adapter: t.Callab
},
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
is_restatement=True,
)
assert (
parse_one(adapter.cursor.execute.call_args[0][0]).sql()
Expand Down Expand Up @@ -2352,6 +2357,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable)
},
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
is_restatement=True,
)

assert (
Expand Down Expand Up @@ -2527,6 +2533,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap
},
execution_time=datetime(2020, 1, 1, 0, 0, 0),
start=datetime(2020, 1, 1, 0, 0, 0),
is_restatement=True,
)

assert (
Expand Down
Loading