Skip to content

Commit aedd0f2

Browse files
committed
Fix: Suppor forward-only changes of model kinds under certain circumstances
1 parent 089ba46 commit aedd0f2

File tree

4 files changed

+98
-15
lines changed

4 files changed

+98
-15
lines changed

sqlmesh/core/plan/builder.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def _categorize_snapshots(
601601
# If the model kind changes mark as breaking
602602
if snapshot.is_model and snapshot.name in self._context_diff.modified_snapshots:
603603
_, old = self._context_diff.modified_snapshots[snapshot.name]
604-
if old.model.kind.name != snapshot.model.kind.name:
604+
if _is_breaking_kind_change(old, snapshot):
605605
category = SnapshotChangeCategory.BREAKING
606606

607607
snapshot.categorize_as(category)
@@ -765,8 +765,8 @@ def _is_forward_only_change(self, s_id: SnapshotId) -> bool:
765765
snapshot = self._context_diff.snapshots[s_id]
766766
if snapshot.name in self._context_diff.modified_snapshots:
767767
_, old = self._context_diff.modified_snapshots[snapshot.name]
768-
# If the model kind has changed, then we should not consider this to be a forward-only change.
769-
if snapshot.is_model and old.model.kind.name != snapshot.model.kind.name:
768+
# If the model kind has changed in a breaking way, then we can't consider this to be a forward-only change.
769+
if snapshot.is_model and _is_breaking_kind_change(old, snapshot):
770770
return False
771771
return (
772772
snapshot.is_model
@@ -882,3 +882,16 @@ def _modified_and_added_snapshots(self) -> t.List[Snapshot]:
882882
if snapshot.name in self._context_diff.modified_snapshots
883883
or snapshot.snapshot_id in self._context_diff.added
884884
]
885+
886+
887+
def _is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool:
888+
if old.model.kind.name == new.model.kind.name:
889+
# If the kind hasn't changed, then it's not a breaking change
890+
return False
891+
if not old.is_incremental or not new.is_incremental:
892+
# If either is not incremental, then it's a breaking change
893+
return True
894+
if old.model.partitioned_by == new.model.partitioned_by:
895+
# If the partitioning hasn't changed, then it's not a breaking change
896+
return False
897+
return True

sqlmesh/dbt/model.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,13 @@ def model_kind(self, context: DbtContext) -> ModelKind:
257257
if field_val is not None:
258258
incremental_by_kind_kwargs[field] = field_val
259259

260+
disable_restatement = self.disable_restatement
261+
if disable_restatement is None:
262+
disable_restatement = (
263+
not self.full_refresh if self.full_refresh is not None else False
264+
)
265+
incremental_kind_kwargs["disable_restatement"] = disable_restatement
266+
260267
if self.time_column:
261268
strategy = self.incremental_strategy or target.default_incremental_strategy(
262269
IncrementalByTimeRangeKind
@@ -270,20 +277,11 @@ def model_kind(self, context: DbtContext) -> ModelKind:
270277

271278
return IncrementalByTimeRangeKind(
272279
time_column=self.time_column,
273-
disable_restatement=(
274-
self.disable_restatement if self.disable_restatement is not None else False
275-
),
276280
auto_restatement_intervals=self.auto_restatement_intervals,
277281
**incremental_kind_kwargs,
278282
**incremental_by_kind_kwargs,
279283
)
280284

281-
disable_restatement = self.disable_restatement
282-
if disable_restatement is None:
283-
disable_restatement = (
284-
not self.full_refresh if self.full_refresh is not None else False
285-
)
286-
287285
if self.unique_key:
288286
strategy = self.incremental_strategy or target.default_incremental_strategy(
289287
IncrementalByUniqueKeyKind
@@ -309,7 +307,6 @@ def model_kind(self, context: DbtContext) -> ModelKind:
309307

310308
return IncrementalByUniqueKeyKind(
311309
unique_key=self.unique_key,
312-
disable_restatement=disable_restatement,
313310
**incremental_kind_kwargs,
314311
**incremental_by_kind_kwargs,
315312
)
@@ -319,7 +316,6 @@ def model_kind(self, context: DbtContext) -> ModelKind:
319316
)
320317
return IncrementalUnmanagedKind(
321318
insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES,
322-
disable_restatement=disable_restatement,
323319
**incremental_kind_kwargs,
324320
)
325321
if materialization == Materialization.EPHEMERAL:

tests/core/test_plan.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@
88
from tests.core.test_table_diff import create_test_console
99
import time_machine
1010
from pytest_mock.plugin import MockerFixture
11-
from sqlglot import parse_one
11+
from sqlglot import parse_one, exp
1212

13+
from sqlmesh.core import dialect as d
1314
from sqlmesh.core.context import Context
1415
from sqlmesh.core.context_diff import ContextDiff
1516
from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentStatements
1617
from sqlmesh.core.model import (
1718
ExternalModel,
1819
FullKind,
1920
IncrementalByTimeRangeKind,
21+
IncrementalUnmanagedKind,
2022
SeedKind,
2123
SeedModel,
2224
SqlModel,
@@ -1724,6 +1726,60 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix
17241726
assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING
17251727

17261728

1729+
@pytest.mark.parametrize(
1730+
"partitioned_by, expected_change_category",
1731+
[
1732+
([], SnapshotChangeCategory.BREAKING),
1733+
([d.parse_one("ds")], SnapshotChangeCategory.FORWARD_ONLY),
1734+
],
1735+
)
1736+
def test_forward_only_models_model_kind_changed_to_incremental_by_time_range(
1737+
make_snapshot,
1738+
partitioned_by: t.List[exp.Expression],
1739+
expected_change_category: SnapshotChangeCategory,
1740+
):
1741+
snapshot = make_snapshot(
1742+
SqlModel(
1743+
name="a",
1744+
query=parse_one("select 1, ds"),
1745+
kind=IncrementalUnmanagedKind(),
1746+
partitioned_by=partitioned_by,
1747+
)
1748+
)
1749+
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
1750+
updated_snapshot = make_snapshot(
1751+
SqlModel(
1752+
name="a",
1753+
query=parse_one("select 3, ds"),
1754+
kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True),
1755+
)
1756+
)
1757+
updated_snapshot.previous_versions = snapshot.all_versions
1758+
1759+
context_diff = ContextDiff(
1760+
environment="test_environment",
1761+
is_new_environment=True,
1762+
is_unfinalized_environment=False,
1763+
normalize_environment_name=True,
1764+
create_from="prod",
1765+
create_from_env_exists=True,
1766+
added=set(),
1767+
removed_snapshots={},
1768+
modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)},
1769+
snapshots={updated_snapshot.snapshot_id: updated_snapshot},
1770+
new_snapshots={updated_snapshot.snapshot_id: updated_snapshot},
1771+
previous_plan_id=None,
1772+
previously_promoted_snapshot_ids=set(),
1773+
previous_finalized_snapshots=None,
1774+
previous_gateway_managed_virtual_layer=False,
1775+
gateway_managed_virtual_layer=False,
1776+
environment_statements=[],
1777+
)
1778+
1779+
PlanBuilder(context_diff, is_dev=True).build()
1780+
assert updated_snapshot.change_category == expected_change_category
1781+
1782+
17271783
def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFixture):
17281784
snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds")))
17291785
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING)

tests/dbt/test_transformation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,24 @@ def test_model_kind():
244244
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False
245245
)
246246

247+
assert ModelConfig(
248+
materialized=Materialization.INCREMENTAL,
249+
time_column="foo",
250+
incremental_strategy="merge",
251+
full_refresh=True,
252+
).model_kind(context) == IncrementalByTimeRangeKind(
253+
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False
254+
)
255+
256+
assert ModelConfig(
257+
materialized=Materialization.INCREMENTAL,
258+
time_column="foo",
259+
incremental_strategy="merge",
260+
full_refresh=False,
261+
).model_kind(context) == IncrementalByTimeRangeKind(
262+
time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=True
263+
)
264+
247265
assert ModelConfig(
248266
materialized=Materialization.INCREMENTAL,
249267
time_column="foo",

0 commit comments

Comments
 (0)