From b4bcb2660304a2c2a64fc46ead261baa459d60b3 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Wed, 16 Jul 2025 18:05:14 +0300 Subject: [PATCH 1/2] Fix: Don't use SCD type 2 restatement logic in regular runs --- sqlmesh/core/engine_adapter/base.py | 21 ++-- sqlmesh/core/engine_adapter/trino.py | 2 + sqlmesh/core/plan/evaluator.py | 7 ++ sqlmesh/core/scheduler.py | 11 ++ sqlmesh/core/snapshot/evaluator.py | 7 ++ tests/core/engine_adapter/test_base.py | 7 ++ tests/core/test_integration.py | 163 +++++++++++++++++++++++++ tests/core/test_snapshot_evaluator.py | 2 + 8 files changed, 213 insertions(+), 7 deletions(-) diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 33ad4c398a..c615a3029d 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -1462,6 +1462,7 @@ def scd_type_2_by_time( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + is_restatement: bool = False, **kwargs: t.Any, ) -> None: self._scd_type_2( @@ -1478,6 +1479,7 @@ def scd_type_2_by_time( table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + is_restatement=is_restatement, **kwargs, ) @@ -1496,6 +1498,7 @@ def scd_type_2_by_column( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + is_restatement: bool = False, **kwargs: t.Any, ) -> None: self._scd_type_2( @@ -1512,6 +1515,7 @@ def scd_type_2_by_column( table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + is_restatement=is_restatement, **kwargs, ) @@ -1533,6 +1537,7 @@ def _scd_type_2( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + is_restatement: bool = False, **kwargs: t.Any, ) -> None: def remove_managed_columns( @@ -1718,13 +1723,15 @@ def remove_managed_columns( target_table ) - cleanup_ts = None if truncate: existing_rows_query = existing_rows_query.limit(0) - else: - # If truncate is false it is not the first insert - # Determine the cleanup timestamp for restatement or a regular incremental run - cleanup_ts = to_time_column(start, time_data_type, self.dialect, nullable=True) + + # Only set cleanup_ts if is_restatement is True and truncate is False (this to enable full restatement) + cleanup_ts = ( + to_time_column(start, time_data_type, self.dialect, nullable=True) + if is_restatement and not truncate + else None + ) with source_queries[0] as source_query: prefixed_columns_to_types = [] @@ -1763,7 +1770,7 @@ def remove_managed_columns( .with_( "static", existing_rows_query.where(valid_to_col.is_(exp.Null()).not_()) - if truncate + if cleanup_ts is None else existing_rows_query.where( exp.and_( valid_to_col.is_(exp.Null().not_()), @@ -1775,7 +1782,7 @@ def remove_managed_columns( .with_( "latest", existing_rows_query.where(valid_to_col.is_(exp.Null())) - if truncate + if cleanup_ts is None else exp.select( *( to_time_column( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 06d693e11c..7862bfca2d 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -267,6 +267,7 @@ def _scd_type_2( table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + is_restatement: bool = False, **kwargs: t.Any, ) -> None: if columns_to_types and self.current_catalog_type == "delta_lake": @@ -289,6 +290,7 @@ def _scd_type_2( table_description, column_descriptions, truncate, + is_restatement, **kwargs, ) diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 9488b9bc91..63719d8138 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -234,6 +234,12 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla return scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator) + # Convert model name restatements to snapshot ID restatements + restatements_by_snapshot_id = { + stage.all_snapshots[name].snapshot_id: interval + for name, interval in plan.restatements.items() + if name in stage.all_snapshots + } errors, _ = scheduler.run_merged_intervals( merged_intervals=stage.snapshot_to_intervals, deployability_index=stage.deployability_index, @@ -242,6 +248,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla circuit_breaker=self._circuit_breaker, start=plan.start, end=plan.end, + restatements=restatements_by_snapshot_id, ) if errors: raise PlanError("Plan application failed.") diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 4582b24485..57541a0690 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -161,6 +161,7 @@ def evaluate( deployability_index: DeployabilityIndex, batch_index: int, environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, + is_restatement: bool = False, **kwargs: t.Any, ) -> t.List[AuditResult]: """Evaluate a snapshot and add the processed interval to the state sync. @@ -192,6 +193,7 @@ def evaluate( snapshots=snapshots, deployability_index=deployability_index, batch_index=batch_index, + is_restatement=is_restatement, **kwargs, ) audit_results = self._audit_snapshot( @@ -371,6 +373,7 @@ def run_merged_intervals( end: t.Optional[TimeLike] = None, run_environment_statements: bool = False, audit_only: bool = False, + restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. @@ -447,6 +450,12 @@ def evaluate_node(node: SchedulingUnit) -> None: execution_time=execution_time, ) else: + # Determine if this snapshot and interval is a restatement (for SCD type 2) + is_restatement = ( + restatements is not None + and snapshot.snapshot_id in restatements + and start >= restatements[snapshot.snapshot_id][0] + ) audit_results = self.evaluate( snapshot=snapshot, environment_naming_info=environment_naming_info, @@ -455,6 +464,7 @@ def evaluate_node(node: SchedulingUnit) -> None: execution_time=execution_time, deployability_index=deployability_index, batch_index=batch_idx, + is_restatement=is_restatement, ) evaluation_duration_ms = now_timestamp() - execution_start_ts @@ -663,6 +673,7 @@ def _run_or_audit( end=end, run_environment_statements=run_environment_statements, audit_only=audit_only, + restatements=remove_intervals, ) return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 993860b527..f8aa08a075 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -140,6 +140,7 @@ def evaluate( snapshots: t.Dict[str, Snapshot], deployability_index: t.Optional[DeployabilityIndex] = None, batch_index: int = 0, + is_restatement: bool = False, **kwargs: t.Any, ) -> t.Optional[str]: """Renders the snapshot's model, executes it and stores the result in the snapshot's physical table. @@ -165,6 +166,7 @@ def evaluate( snapshots, deployability_index=deployability_index, batch_index=batch_index, + is_restatement=is_restatement, **kwargs, ) if result is None or isinstance(result, str): @@ -622,6 +624,7 @@ def _evaluate_snapshot( limit: t.Optional[int] = None, deployability_index: t.Optional[DeployabilityIndex] = None, batch_index: int = 0, + is_restatement: bool = False, **kwargs: t.Any, ) -> DF | str | None: """Renders the snapshot's model and executes it. The return value depends on whether the limit was specified. @@ -694,6 +697,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: end=end, execution_time=execution_time, physical_properties=rendered_physical_properties, + is_restatement=is_restatement, ) else: logger.info( @@ -715,6 +719,7 @@ def apply(query_or_df: QueryOrDF, index: int = 0) -> None: end=end, execution_time=execution_time, physical_properties=rendered_physical_properties, + is_restatement=is_restatement, ) with ( @@ -1833,6 +1838,7 @@ def insert( column_descriptions=model.column_descriptions, truncate=is_first_insert, start=kwargs["start"], + is_restatement=kwargs.get("is_restatement", False), ) elif isinstance(model.kind, SCDType2ByColumnKind): self.adapter.scd_type_2_by_column( @@ -1851,6 +1857,7 @@ def insert( column_descriptions=model.column_descriptions, truncate=is_first_insert, start=kwargs["start"], + is_restatement=kwargs.get("is_restatement", False), ) else: raise SQLMeshError( diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index faf1386877..6c9d2ee132 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -1223,6 +1223,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): }, execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) assert ( @@ -1422,6 +1423,7 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte }, execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) assert ( @@ -1610,6 +1612,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): }, execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) assert ( @@ -1799,6 +1802,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), extra_col_ignore="testing", + is_restatement=True, ) assert ( @@ -1990,6 +1994,7 @@ def test_scd_type_2_by_column_composite_key(make_mocked_engine_adapter: t.Callab }, execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) assert ( parse_one(adapter.cursor.execute.call_args[0][0]).sql() @@ -2352,6 +2357,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) }, execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) assert ( @@ -2527,6 +2533,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap }, execution_time=datetime(2020, 1, 1, 0, 0, 0), start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) assert ( diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 7337f8d3f4..a3a54eb7b6 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -6916,3 +6916,166 @@ def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger): assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}" assert _correlation_id_in_sqls(correlation_id, mock_logger) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_scd_type_2_regular_run_with_offset(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + raw_employee_status = d.parse(""" + MODEL ( + name memory.hr_system.raw_employee_status, + kind FULL + ); + + SELECT + 1001 AS employee_id, + 'engineering' AS department, + 'EMEA' AS region, + '2023-01-08 15:00:00 UTC' AS last_modified; + """) + + employee_history = d.parse(""" + MODEL ( + name memory.hr_system.employee_history, + kind SCD_TYPE_2_BY_TIME ( + unique_key employee_id, + updated_at_name last_modified, + disable_restatement false + ), + owner hr_analytics, + cron '0 7 * * *', + grain employee_id, + description 'Historical tracking of employee status changes' + ); + + SELECT + employee_id::INT AS employee_id, + department::TEXT AS department, + region::TEXT AS region, + last_modified AS last_modified + FROM + memory.hr_system.raw_employee_status; + """) + + raw_employee_status_model = load_sql_based_model(raw_employee_status) + employee_history_model = load_sql_based_model(employee_history) + context.upsert_model(raw_employee_status_model) + context.upsert_model(employee_history_model) + + # Initial plan and apply + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + query = "SELECT employee_id, department, region, valid_from, valid_to FROM memory.hr_system.employee_history ORDER BY employee_id, valid_from" + initial_data = context.engine_adapter.fetchdf(query) + + assert len(initial_data) == 1 + assert initial_data["valid_to"].isna().all() + assert initial_data["department"].tolist() == ["engineering"] + assert initial_data["region"].tolist() == ["EMEA"] + + # Apply a future plan with source changes a few hours before the cron time of the SCD Type 2 model BUT on the same day + with time_machine.travel("2023-01-09 00:10:00 UTC"): + raw_employee_status_v2 = d.parse(""" + MODEL ( + name memory.hr_system.raw_employee_status, + kind FULL + ); + + SELECT + 1001 AS employee_id, + 'engineering' AS department, + 'AMER' AS region, + '2023-01-09 00:10:00 UTC' AS last_modified; + """) + raw_employee_status_v2_model = load_sql_based_model(raw_employee_status_v2) + context.upsert_model(raw_employee_status_v2_model) + context.plan( + auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full() + ) + + # The 7th hour of the day the run is kicked off for the SCD Type 2 model + with time_machine.travel("2023-01-09 07:00:01 UTC"): + context.run() + data_after_change = context.engine_adapter.fetchdf(query) + + # Validate the SCD2 records for employee 1001 + assert len(data_after_change) == 2 + assert data_after_change.iloc[0]["employee_id"] == 1001 + assert data_after_change.iloc[0]["department"] == "engineering" + assert data_after_change.iloc[0]["region"] == "EMEA" + assert str(data_after_change.iloc[0]["valid_from"]) == "1970-01-01 00:00:00" + assert str(data_after_change.iloc[0]["valid_to"]) == "2023-01-09 00:10:00" + assert data_after_change.iloc[1]["employee_id"] == 1001 + assert data_after_change.iloc[1]["department"] == "engineering" + assert data_after_change.iloc[1]["region"] == "AMER" + assert str(data_after_change.iloc[1]["valid_from"]) == "2023-01-09 00:10:00" + assert pd.isna(data_after_change.iloc[1]["valid_to"]) + + # Update source model again a bit later on the same day + raw_employee_status_v2 = d.parse(""" + MODEL ( + name memory.hr_system.raw_employee_status, + kind FULL + ); + + SELECT + 1001 AS employee_id, + 'sales' AS department, + 'ANZ' AS region, + '2023-01-09 07:26:00 UTC' AS last_modified; + """) + raw_employee_status_v2_model = load_sql_based_model(raw_employee_status_v2) + context.upsert_model(raw_employee_status_v2_model) + context.plan( + auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full() + ) + + # A day later the run is kicked off for the SCD Type 2 model again + with time_machine.travel("2023-01-10 07:00:00 UTC"): + context.run() + data_after_change = context.engine_adapter.fetchdf(query) + + # Validate the SCD2 history for employee 1001 after second change with the historical records intact + assert len(data_after_change) == 3 + assert data_after_change.iloc[0]["employee_id"] == 1001 + assert data_after_change.iloc[0]["department"] == "engineering" + assert data_after_change.iloc[0]["region"] == "EMEA" + assert str(data_after_change.iloc[0]["valid_from"]) == "1970-01-01 00:00:00" + assert str(data_after_change.iloc[0]["valid_to"]) == "2023-01-09 00:10:00" + assert data_after_change.iloc[1]["employee_id"] == 1001 + assert data_after_change.iloc[1]["department"] == "engineering" + assert data_after_change.iloc[1]["region"] == "AMER" + assert str(data_after_change.iloc[1]["valid_from"]) == "2023-01-09 00:10:00" + assert str(data_after_change.iloc[1]["valid_to"]) == "2023-01-09 07:26:00" + assert data_after_change.iloc[2]["employee_id"] == 1001 + assert data_after_change.iloc[2]["department"] == "sales" + assert data_after_change.iloc[2]["region"] == "ANZ" + assert str(data_after_change.iloc[2]["valid_from"]) == "2023-01-09 07:26:00" + assert pd.isna(data_after_change.iloc[2]["valid_to"]) + + # Now test restatement still works as expected by restating from 2023-01-09 00:10:00 (first change) + with time_machine.travel("2023-01-10 07:38:00 UTC"): + plan = context.plan_builder( + "prod", + skip_tests=True, + restate_models=["memory.hr_system.employee_history"], + start="2023-01-09 00:10:00", + ).build() + context.apply(plan) + restated_data = context.engine_adapter.fetchdf(query) + + # Validate the SCD2 history after restatement + assert len(restated_data) == 2 + assert restated_data.iloc[0]["employee_id"] == 1001 + assert restated_data.iloc[0]["department"] == "engineering" + assert restated_data.iloc[0]["region"] == "EMEA" + assert str(restated_data.iloc[0]["valid_from"]) == "1970-01-01 00:00:00" + assert str(restated_data.iloc[0]["valid_to"]) == "2023-01-09 07:26:00" + assert restated_data.iloc[1]["employee_id"] == 1001 + assert restated_data.iloc[1]["department"] == "sales" + assert restated_data.iloc[1]["region"] == "ANZ" + assert str(restated_data.iloc[1]["valid_from"]) == "2023-01-09 07:26:00" + assert pd.isna(restated_data.iloc[1]["valid_to"]) diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 93cef90daf..b01daf9e20 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -1973,6 +1973,7 @@ def test_insert_into_scd_type_2_by_time( column_descriptions={}, updated_at_as_valid_from=False, truncate=truncate, + is_restatement=False, start="2020-01-01", ) adapter_mock.columns.assert_called_once_with(snapshot.table_name()) @@ -2146,6 +2147,7 @@ def test_insert_into_scd_type_2_by_column( table_description=None, column_descriptions={}, truncate=truncate, + is_restatement=False, start="2020-01-01", ) adapter_mock.columns.assert_called_once_with(snapshot.table_name()) From 2d9c32b78290d732986d04685d8ea371508e1571 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Wed, 16 Jul 2025 18:38:16 +0300 Subject: [PATCH 2/2] address comments --- sqlmesh/core/plan/evaluator.py | 1 - sqlmesh/core/scheduler.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 63719d8138..bb779fffe9 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -238,7 +238,6 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla restatements_by_snapshot_id = { stage.all_snapshots[name].snapshot_id: interval for name, interval in plan.restatements.items() - if name in stage.all_snapshots } errors, _ = scheduler.run_merged_intervals( merged_intervals=stage.snapshot_to_intervals, diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 57541a0690..7177efe927 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -452,9 +452,7 @@ def evaluate_node(node: SchedulingUnit) -> None: else: # Determine if this snapshot and interval is a restatement (for SCD type 2) is_restatement = ( - restatements is not None - and snapshot.snapshot_id in restatements - and start >= restatements[snapshot.snapshot_id][0] + restatements is not None and snapshot.snapshot_id in restatements ) audit_results = self.evaluate( snapshot=snapshot,