Skip to content

Commit 41bbafa

Browse files
committed
Fix!: Avoid using rendered query when computing the data hash
1 parent 254e9e1 commit 41bbafa

File tree

10 files changed

+110
-40
lines changed

10 files changed

+110
-40
lines changed

sqlmesh/core/context_diff.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def directly_modified(self, name: str) -> bool:
435435
return False
436436

437437
current, previous = self.modified_snapshots[name]
438-
return current.fingerprint.data_hash != previous.fingerprint.data_hash
438+
return current.is_directly_modified(previous)
439439

440440
def indirectly_modified(self, name: str) -> bool:
441441
"""Returns whether or not a node was indirectly modified in this context.
@@ -451,10 +451,7 @@ def indirectly_modified(self, name: str) -> bool:
451451
return False
452452

453453
current, previous = self.modified_snapshots[name]
454-
return (
455-
current.fingerprint.data_hash == previous.fingerprint.data_hash
456-
and current.fingerprint.parent_data_hash != previous.fingerprint.parent_data_hash
457-
)
454+
return current.is_indirectly_modified(previous)
458455

459456
def metadata_updated(self, name: str) -> bool:
460457
"""Returns whether or not the given node's metadata has been updated.
@@ -470,7 +467,7 @@ def metadata_updated(self, name: str) -> bool:
470467
return False
471468

472469
current, previous = self.modified_snapshots[name]
473-
return current.fingerprint.metadata_hash != previous.fingerprint.metadata_hash
470+
return current.is_metadata_updated(previous)
474471

475472
def text_diff(self, name: str) -> str:
476473
"""Finds the difference of a node between the current and remote environment.

sqlmesh/core/model/definition.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262

6363
if t.TYPE_CHECKING:
6464
from sqlglot.dialects.dialect import DialectType
65+
from sqlmesh.core.node import _Node
6566
from sqlmesh.core._typing import Self, TableName, SessionProperties
6667
from sqlmesh.core.context import ExecutionContext
6768
from sqlmesh.core.engine_adapter import EngineAdapter
@@ -1278,6 +1279,7 @@ class SqlModel(_Model):
12781279
source_type: t.Literal["sql"] = "sql"
12791280

12801281
_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None
1282+
_is_metadata_only_change_cache: t.Dict[int, bool] = {}
12811283

12821284
def __getstate__(self) -> t.Dict[t.Any, t.Any]:
12831285
state = super().__getstate__()
@@ -1500,6 +1502,27 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]:
15001502

15011503
return False
15021504

1505+
def is_metadata_only_change(self, previous: _Node) -> bool:
1506+
if self._is_metadata_only_change_cache.get(id(previous), None) is not None:
1507+
return self._is_metadata_only_change_cache[id(previous)]
1508+
1509+
if (
1510+
not isinstance(previous, SqlModel)
1511+
or self.metadata_hash == previous.metadata_hash
1512+
or self._data_hash_values_no_query != previous._data_hash_values_no_query
1513+
):
1514+
is_metadata_change = False
1515+
else:
1516+
# If the rendered queries are the same, then this is a metadata only change
1517+
this_rendered_query = self.render_query()
1518+
previous_rendered_query = previous.render_query()
1519+
is_metadata_change = (
1520+
this_rendered_query is not None and this_rendered_query == previous_rendered_query
1521+
)
1522+
1523+
self._is_metadata_only_change_cache[id(previous)] = is_metadata_change
1524+
return is_metadata_change
1525+
15031526
@cached_property
15041527
def _query_renderer(self) -> QueryRenderer:
15051528
no_quote_identifiers = self.kind.is_view and self.dialect in ("trino", "spark")
@@ -1519,17 +1542,22 @@ def _query_renderer(self) -> QueryRenderer:
15191542
)
15201543

15211544
@property
1522-
def _data_hash_values(self) -> t.List[str]:
1523-
data = super()._data_hash_values
1545+
def _data_hash_values_no_query(self) -> t.List[str]:
1546+
return [
1547+
*super()._data_hash_values,
1548+
*self.jinja_macros.data_hash_values,
1549+
]
15241550

1525-
query = self.render_query() or self.query
1526-
data.append(gen(query))
1527-
data.extend(self.jinja_macros.data_hash_values)
1528-
return data
1551+
@property
1552+
def _data_hash_values(self) -> t.List[str]:
1553+
return [
1554+
*self._data_hash_values_no_query,
1555+
gen(self.query, comments=False),
1556+
]
15291557

15301558
@property
15311559
def _additional_metadata(self) -> t.List[str]:
1532-
return [*super()._additional_metadata, gen(self.query)]
1560+
return [*super()._additional_metadata, gen(self.query, comments=True)]
15331561

15341562
@property
15351563
def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]:

sqlmesh/core/node.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -307,16 +307,6 @@ def batch_concurrency(self) -> t.Optional[int]:
307307
"""The maximal number of batches that can run concurrently for a backfill."""
308308
return None
309309

310-
@property
311-
def data_hash(self) -> str:
312-
"""
313-
Computes the data hash for the node.
314-
315-
Returns:
316-
The data hash for the node.
317-
"""
318-
raise NotImplementedError
319-
320310
@property
321311
def interval_unit(self) -> IntervalUnit:
322312
"""Returns the interval unit using which data intervals are computed for this node."""
@@ -332,6 +322,16 @@ def depends_on(self) -> t.Set[str]:
332322
def fqn(self) -> str:
333323
return self.name
334324

325+
@property
326+
def data_hash(self) -> str:
327+
"""
328+
Computes the data hash for the node.
329+
330+
Returns:
331+
The data hash for the node.
332+
"""
333+
raise NotImplementedError
334+
335335
@property
336336
def metadata_hash(self) -> str:
337337
"""
@@ -342,6 +342,30 @@ def metadata_hash(self) -> str:
342342
"""
343343
raise NotImplementedError
344344

345+
def is_metadata_only_change(self, previous: _Node) -> bool:
346+
"""Determines if this node is a metadata only change in relation to the `previous` node.
347+
348+
Args:
349+
previous: The previous node to compare against.
350+
351+
Returns:
352+
True if this node is a metadata only change, False otherwise.
353+
"""
354+
return self.data_hash == previous.data_hash and self.metadata_hash != previous.metadata_hash
355+
356+
def is_data_change(self, previous: _Node) -> bool:
357+
"""Determines if this node is a data change in relation to the `previous` node.
358+
359+
Args:
360+
previous: The previous node to compare against.
361+
362+
Returns:
363+
True if this node is a data change, False otherwise.
364+
"""
365+
return (
366+
self.data_hash != previous.data_hash or self.metadata_hash != previous.metadata_hash
367+
) and not self.is_metadata_only_change(previous)
368+
345369
def croniter(self, value: TimeLike) -> CroniterCache:
346370
if self._croniter is None:
347371
self._croniter = CroniterCache(self.cron, value, tz=self.cron_tz)

sqlmesh/core/snapshot/categorizer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ def categorize_change(
4747
if type(new_model) != type(old_model):
4848
return default_category
4949

50-
if new.fingerprint.data_hash == old.fingerprint.data_hash:
51-
if new.fingerprint.metadata_hash == old.fingerprint.metadata_hash:
52-
raise SQLMeshError(
53-
f"{new} is unmodified or indirectly modified and should not be categorized"
54-
)
50+
if new.fingerprint == old.fingerprint:
51+
raise SQLMeshError(
52+
f"{new} is unmodified or indirectly modified and should not be categorized"
53+
)
54+
55+
if not new.is_directly_modified(old):
5556
if new.fingerprint.parent_data_hash == old.fingerprint.parent_data_hash:
5657
return SnapshotChangeCategory.NON_BREAKING
5758
return None

sqlmesh/core/snapshot/definition.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,21 @@ def apply_pending_restatement_intervals(self) -> None:
12301230
)
12311231
self.intervals = remove_interval(self.intervals, *pending_restatement_interval)
12321232

1233+
def is_directly_modified(self, other: Snapshot) -> bool:
1234+
"""Returns whether or not this snapshot is directly modified in relation to the other snapshot."""
1235+
return self.node.is_data_change(other.node)
1236+
1237+
def is_indirectly_modified(self, other: Snapshot) -> bool:
1238+
"""Returns whether or not this snapshot is indirectly modified in relation to the other snapshot."""
1239+
return (
1240+
self.fingerprint.parent_data_hash != other.fingerprint.parent_data_hash
1241+
and not self.node.is_data_change(other.node)
1242+
)
1243+
1244+
def is_metadata_updated(self, other: Snapshot) -> bool:
1245+
"""Returns whether or not this snapshot contains metadata changes in relation to the other snapshot."""
1246+
return self.fingerprint.metadata_hash != other.fingerprint.metadata_hash
1247+
12331248
@property
12341249
def physical_schema(self) -> str:
12351250
if self.physical_schema_ is not None:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Use the unrendered query when computing the model fingerprint."""
2+
3+
4+
def migrate(state_sync, **kwargs): # type: ignore
5+
pass

tests/core/test_integration.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4681,12 +4681,12 @@ def test_plan_repairs_unrenderable_snapshot_state(
46814681
f"name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'",
46824682
)
46834683

4684+
context.clear_caches()
4685+
target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[
4686+
target_snapshot.snapshot_id
4687+
]
4688+
46844689
with pytest.raises(Exception):
4685-
context_copy = context.copy()
4686-
context_copy.clear_caches()
4687-
target_snapshot_in_state = context_copy.state_sync.get_snapshots(
4688-
[target_snapshot.snapshot_id]
4689-
)[target_snapshot.snapshot_id]
46904690
target_snapshot_in_state.model.render_query_or_raise()
46914691

46924692
# Repair the snapshot by creating a new version of it
@@ -4695,11 +4695,11 @@ def test_plan_repairs_unrenderable_snapshot_state(
46954695

46964696
plan_builder = context.plan_builder("prod", forward_only=forward_only)
46974697
plan = plan_builder.build()
4698-
assert plan.directly_modified == {target_snapshot.snapshot_id}
46994698
if not forward_only:
47004699
assert target_snapshot.snapshot_id in {i.snapshot_id for i in plan.missing_intervals}
4701-
plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING)
4702-
plan = plan_builder.build()
4700+
assert plan.directly_modified == {target_snapshot.snapshot_id}
4701+
plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING)
4702+
plan = plan_builder.build()
47034703

47044704
context.apply(plan)
47054705

tests/core/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5735,7 +5735,7 @@ def test_default_catalog_sql(assert_exp_eq):
57355735
The system is not designed to actually support having an engine that doesn't support default catalog
57365736
to start supporting it or the reverse of that. If that did happen then bugs would occur.
57375737
"""
5738-
HASH_WITH_CATALOG = "1269513823"
5738+
HASH_WITH_CATALOG = "3443912775"
57395739

57405740
# Test setting default catalog doesn't change hash if it matches existing logic
57415741
expressions = d.parse(

tests/core/test_selector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_select_change_schema(mocker: MockerFixture, make_snapshot):
301301
selector = Selector(state_reader_mock, local_models)
302302

303303
selected = selector.select_models(["db.parent"], env_name)
304-
assert selected[local_child.fqn].data_hash != child.data_hash
304+
assert selected[local_child.fqn].render_query() != child.render_query()
305305

306306
_assert_models_equal(
307307
selected,

tests/core/test_snapshot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,7 @@ def test_fingerprint(model: Model, parent_model: Model):
913913
fingerprint = fingerprint_from_node(model, nodes={})
914914

915915
original_fingerprint = SnapshotFingerprint(
916-
data_hash="3301649319",
916+
data_hash="1698409777",
917917
metadata_hash="3575333731",
918918
)
919919

@@ -1013,7 +1013,7 @@ def test_fingerprint_jinja_macros(model: Model):
10131013
}
10141014
)
10151015
original_fingerprint = SnapshotFingerprint(
1016-
data_hash="2908339239",
1016+
data_hash="343517722",
10171017
metadata_hash="3575333731",
10181018
)
10191019

0 commit comments

Comments
 (0)