diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index 567920997e..6f3c7f0805 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -601,7 +601,7 @@ def _categorize_snapshots( # If the model kind changes mark as breaking if snapshot.is_model and snapshot.name in self._context_diff.modified_snapshots: _, old = self._context_diff.modified_snapshots[snapshot.name] - if old.model.kind.name != snapshot.model.kind.name: + if _is_breaking_kind_change(old, snapshot): category = SnapshotChangeCategory.BREAKING snapshot.categorize_as(category) @@ -765,8 +765,8 @@ def _is_forward_only_change(self, s_id: SnapshotId) -> bool: snapshot = self._context_diff.snapshots[s_id] if snapshot.name in self._context_diff.modified_snapshots: _, old = self._context_diff.modified_snapshots[snapshot.name] - # If the model kind has changed, then we should not consider this to be a forward-only change. - if snapshot.is_model and old.model.kind.name != snapshot.model.kind.name: + # If the model kind has changed in a breaking way, then we can't consider this to be a forward-only change. + if snapshot.is_model and _is_breaking_kind_change(old, snapshot): return False return ( snapshot.is_model @@ -882,3 +882,16 @@ def _modified_and_added_snapshots(self) -> t.List[Snapshot]: if snapshot.name in self._context_diff.modified_snapshots or snapshot.snapshot_id in self._context_diff.added ] + + +def _is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool: + if old.model.kind.name == new.model.kind.name: + # If the kind hasn't changed, then it's not a breaking change + return False + if not old.is_incremental or not new.is_incremental: + # If either is not incremental, then it's a breaking change + return True + if old.model.partitioned_by == new.model.partitioned_by: + # If the partitioning hasn't changed, then it's not a breaking change + return False + return True diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 4cbca09aee..e34e66e119 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -257,6 +257,13 @@ def model_kind(self, context: DbtContext) -> ModelKind: if field_val is not None: incremental_by_kind_kwargs[field] = field_val + disable_restatement = self.disable_restatement + if disable_restatement is None: + disable_restatement = ( + not self.full_refresh if self.full_refresh is not None else False + ) + incremental_kind_kwargs["disable_restatement"] = disable_restatement + if self.time_column: strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalByTimeRangeKind @@ -270,20 +277,11 @@ def model_kind(self, context: DbtContext) -> ModelKind: return IncrementalByTimeRangeKind( time_column=self.time_column, - disable_restatement=( - self.disable_restatement if self.disable_restatement is not None else False - ), auto_restatement_intervals=self.auto_restatement_intervals, **incremental_kind_kwargs, **incremental_by_kind_kwargs, ) - disable_restatement = self.disable_restatement - if disable_restatement is None: - disable_restatement = ( - not self.full_refresh if self.full_refresh is not None else False - ) - if self.unique_key: strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalByUniqueKeyKind @@ -309,7 +307,6 @@ def model_kind(self, context: DbtContext) -> ModelKind: return IncrementalByUniqueKeyKind( unique_key=self.unique_key, - disable_restatement=disable_restatement, **incremental_kind_kwargs, **incremental_by_kind_kwargs, ) @@ -319,7 +316,6 @@ def model_kind(self, context: DbtContext) -> ModelKind: ) return IncrementalUnmanagedKind( insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES, - disable_restatement=disable_restatement, **incremental_kind_kwargs, ) if materialization == Materialization.EPHEMERAL: diff --git a/tests/core/test_plan.py b/tests/core/test_plan.py index 765a45f7c8..efaeba8623 100644 --- a/tests/core/test_plan.py +++ b/tests/core/test_plan.py @@ -8,8 +8,9 @@ from tests.core.test_table_diff import create_test_console import time_machine from pytest_mock.plugin import MockerFixture -from sqlglot import parse_one +from sqlglot import parse_one, exp +from sqlmesh.core import dialect as d from sqlmesh.core.context import Context from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentStatements @@ -17,6 +18,7 @@ ExternalModel, FullKind, IncrementalByTimeRangeKind, + IncrementalUnmanagedKind, SeedKind, SeedModel, SqlModel, @@ -1724,6 +1726,60 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING +@pytest.mark.parametrize( + "partitioned_by, expected_change_category", + [ + ([], SnapshotChangeCategory.BREAKING), + ([d.parse_one("ds")], SnapshotChangeCategory.FORWARD_ONLY), + ], +) +def test_forward_only_models_model_kind_changed_to_incremental_by_time_range( + make_snapshot, + partitioned_by: t.List[exp.Expression], + expected_change_category: SnapshotChangeCategory, +): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + kind=IncrementalUnmanagedKind(), + partitioned_by=partitioned_by, + ) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + updated_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 3, ds"), + kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True), + ) + ) + updated_snapshot.previous_versions = snapshot.all_versions + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, + snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + new_snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + PlanBuilder(context_diff, is_dev=True).build() + assert updated_snapshot.change_category == expected_change_category + + def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFixture): snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds"))) snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 17b8a6f313..e483e45ae1 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -244,6 +244,24 @@ def test_model_kind(): time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False ) + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + time_column="foo", + incremental_strategy="merge", + full_refresh=True, + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False + ) + + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + time_column="foo", + incremental_strategy="merge", + full_refresh=False, + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=True + ) + assert ModelConfig( materialized=Materialization.INCREMENTAL, time_column="foo",