Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions sqlmesh/core/plan/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def _categorize_snapshots(
# If the model kind changes mark as breaking
if snapshot.is_model and snapshot.name in self._context_diff.modified_snapshots:
_, old = self._context_diff.modified_snapshots[snapshot.name]
if old.model.kind.name != snapshot.model.kind.name:
if _is_breaking_kind_change(old, snapshot):
category = SnapshotChangeCategory.BREAKING

snapshot.categorize_as(category)
Expand Down Expand Up @@ -765,8 +765,8 @@ def _is_forward_only_change(self, s_id: SnapshotId) -> bool:
snapshot = self._context_diff.snapshots[s_id]
if snapshot.name in self._context_diff.modified_snapshots:
_, old = self._context_diff.modified_snapshots[snapshot.name]
# If the model kind has changed, then we should not consider this to be a forward-only change.
if snapshot.is_model and old.model.kind.name != snapshot.model.kind.name:
# If the model kind has changed in a breaking way, then we can't consider this to be a forward-only change.
if snapshot.is_model and _is_breaking_kind_change(old, snapshot):
return False
return (
snapshot.is_model
Expand Down Expand Up @@ -882,3 +882,16 @@ def _modified_and_added_snapshots(self) -> t.List[Snapshot]:
if snapshot.name in self._context_diff.modified_snapshots
or snapshot.snapshot_id in self._context_diff.added
]


def _is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
if old.model.kind.name == new.model.kind.name:
# If the kind hasn't changed, then it's not a breaking change
return False
if not old.is_incremental or not new.is_incremental:
# If either is not incremental, then it's a breaking change
return True
if old.model.partitioned_by == new.model.partitioned_by:
# If the partitioning hasn't changed, then it's not a breaking change
return False
return True
18 changes: 7 additions & 11 deletions sqlmesh/dbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ def model_kind(self, context: DbtContext) -> ModelKind:
if field_val is not None:
incremental_by_kind_kwargs[field] = field_val

disable_restatement = self.disable_restatement
if disable_restatement is None:
disable_restatement = (
not self.full_refresh if self.full_refresh is not None else False
)
incremental_kind_kwargs["disable_restatement"] = disable_restatement

if self.time_column:
strategy = self.incremental_strategy or target.default_incremental_strategy(
IncrementalByTimeRangeKind
Expand All @@ -270,20 +277,11 @@ def model_kind(self, context: DbtContext) -> ModelKind:

return IncrementalByTimeRangeKind(
time_column=self.time_column,
disable_restatement=(
self.disable_restatement if self.disable_restatement is not None else False
),
auto_restatement_intervals=self.auto_restatement_intervals,
**incremental_kind_kwargs,
**incremental_by_kind_kwargs,
)

disable_restatement = self.disable_restatement
if disable_restatement is None:
disable_restatement = (
not self.full_refresh if self.full_refresh is not None else False
)

if self.unique_key:
strategy = self.incremental_strategy or target.default_incremental_strategy(
IncrementalByUniqueKeyKind
Expand All @@ -309,7 +307,6 @@ def model_kind(self, context: DbtContext) -> ModelKind:

return IncrementalByUniqueKeyKind(
unique_key=self.unique_key,
disable_restatement=disable_restatement,
**incremental_kind_kwargs,
**incremental_by_kind_kwargs,
)
Expand All @@ -319,7 +316,6 @@ def model_kind(self, context: DbtContext) -> ModelKind:
)
return IncrementalUnmanagedKind(
insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES,
disable_restatement=disable_restatement,
**incremental_kind_kwargs,
)
if materialization == Materialization.EPHEMERAL:
Expand Down
58 changes: 57 additions & 1 deletion tests/core/test_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
from tests.core.test_table_diff import create_test_console
import time_machine
from pytest_mock.plugin import MockerFixture
from sqlglot import parse_one
from sqlglot import parse_one, exp

from sqlmesh.core import dialect as d
from sqlmesh.core.context import Context
from sqlmesh.core.context_diff import ContextDiff
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentStatements
from sqlmesh.core.model import (
ExternalModel,
FullKind,
IncrementalByTimeRangeKind,
IncrementalUnmanagedKind,
SeedKind,
SeedModel,
SqlModel,
Expand Down Expand Up @@ -1724,6 +1726,60 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix
assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING


@pytest.mark.parametrize(
"partitioned_by, expected_change_category",
[
([], SnapshotChangeCategory.BREAKING),
([d.parse_one("ds")], SnapshotChangeCategory.FORWARD_ONLY),
],
)
def test_forward_only_models_model_kind_changed_to_incremental_by_time_range(
make_snapshot,
partitioned_by: t.List[exp.Expression],
expected_change_category: SnapshotChangeCategory,
):
snapshot = make_snapshot(
SqlModel(
name="a",
query=parse_one("select 1, ds"),
kind=IncrementalUnmanagedKind(),
partitioned_by=partitioned_by,
)
)
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
updated_snapshot = make_snapshot(
SqlModel(
name="a",
query=parse_one("select 3, ds"),
kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True),
)
)
updated_snapshot.previous_versions = snapshot.all_versions

context_diff = ContextDiff(
environment="test_environment",
is_new_environment=True,
is_unfinalized_environment=False,
normalize_environment_name=True,
create_from="prod",
create_from_env_exists=True,
added=set(),
removed_snapshots={},
modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)},
snapshots={updated_snapshot.snapshot_id: updated_snapshot},
new_snapshots={updated_snapshot.snapshot_id: updated_snapshot},
previous_plan_id=None,
previously_promoted_snapshot_ids=set(),
previous_finalized_snapshots=None,
previous_gateway_managed_virtual_layer=False,
gateway_managed_virtual_layer=False,
environment_statements=[],
)

PlanBuilder(context_diff, is_dev=True).build()
assert updated_snapshot.change_category == expected_change_category


def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFixture):
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds")))
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)
Expand Down
18 changes: 18 additions & 0 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,24 @@ def test_model_kind():
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False
)

assert ModelConfig(
materialized=Materialization.INCREMENTAL,
time_column="foo",
incremental_strategy="merge",
full_refresh=True,
).model_kind(context) == IncrementalByTimeRangeKind(
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False
)

assert ModelConfig(
materialized=Materialization.INCREMENTAL,
time_column="foo",
incremental_strategy="merge",
full_refresh=False,
).model_kind(context) == IncrementalByTimeRangeKind(
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=True
)

assert ModelConfig(
materialized=Materialization.INCREMENTAL,
time_column="foo",
Expand Down