Skip to content

Commit 74a63ec

Browse files
committed
fix: dbt prevent all cycles from tests
1 parent 16a032f commit 74a63ec

File tree

4 files changed: 205 additions & 78 deletions

4 files changed: 205 additions & 78 deletions

sqlmesh/dbt/basemodel.py

Lines changed: 0 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -268,33 +268,6 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
268268
and all(source in context.sources for source in test.dependencies.sources)
269269
]
270270

271-
def fix_circular_test_refs(self, context: DbtContext) -> None:
272-
"""
273-
Checks for direct circular references between two models and moves the test to the downstream
274-
model if found. This addresses the most common circular reference - relationship tests in both
275-
directions. In the future, we may want to increase coverage by checking for indirect circular references.
276-
277-
Args:
278-
context: The dbt context this model resides within.
279-
280-
Returns:
281-
None
282-
"""
283-
for test in self.tests.copy():
284-
for ref in test.dependencies.refs:
285-
if ref == self.name or ref in self.dependencies.refs:
286-
continue
287-
model = context.refs[ref]
288-
if (
289-
self.name in model.dependencies.refs
290-
or self.name in model.tests_ref_source_dependencies.refs
291-
):
292-
logger.info(
293-
f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference."
294-
)
295-
model.tests.append(test)
296-
self.tests.remove(test)
297-
298271
@property
299272
def sqlmesh_config_fields(self) -> t.Set[str]:
300273
return {"description", "owner", "stamp", "storage_format"}
@@ -314,7 +287,6 @@ def sqlmesh_model_kwargs(
314287
) -> t.Dict[str, t.Any]:
315288
"""Get common sqlmesh model parameters"""
316289
self.remove_tests_with_invalid_refs(context)
317-
self.fix_circular_test_refs(context)
318290

319291
dependencies = self.dependencies.copy()
320292
if dependencies.has_dynamic_var_names:

sqlmesh/dbt/manifest.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -338,10 +338,12 @@ def _load_models_and_seeds(self) -> None:
338338
continue
339339

340340
macro_references = _macro_references(self._manifest, node)
341-
tests = (
341+
all_tests = (
342342
self._tests_by_owner[node.name]
343343
+ self._tests_by_owner[f"{node.package_name}.{node.name}"]
344344
)
345+
# Only include non-standalone tests (tests that don't reference other models)
346+
tests = [test for test in all_tests if not test.is_standalone]
345347
node_config = _node_base_config(node)
346348

347349
node_name = node.name

sqlmesh/dbt/test.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,15 @@ def _lowercase_name(cls, v: str) -> str:
108108

109109
@property
110110
def is_standalone(self) -> bool:
111-
return not self.model_name
111+
# A test is standalone if:
112+
# 1. It has no model_name (already standalone), OR
113+
# 2. It references other models besides its own model
114+
if not self.model_name:
115+
return True
116+
117+
# Check if test has references to other models
118+
other_refs = {ref for ref in self.dependencies.refs if ref != self.model_name}
119+
return bool(other_refs)
112120

113121
@property
114122
def sqlmesh_config_fields(self) -> t.Set[str]:

tests/dbt/test_model.py

Lines changed: 193 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -63,57 +63,202 @@ def _create_empty_project() -> t.Tuple[Path, Path]:
6363
return _create_empty_project
6464

6565

66-
def test_model_test_circular_references() -> None:
67-
upstream_model = ModelConfig(name="upstream")
68-
downstream_model = ModelConfig(name="downstream", dependencies=Dependencies(refs={"upstream"}))
69-
context = DbtContext(_refs={"upstream": upstream_model, "downstream": downstream_model})
70-
71-
# Test and downstream model references
72-
downstream_test = TestConfig(
73-
name="downstream_with_upstream",
74-
sql="",
75-
dependencies=Dependencies(refs={"upstream", "downstream"}),
66+
def test_test_config_is_standalone_behavior() -> None:
67+
"""Test that TestConfig.is_standalone correctly identifies tests with cross-model references"""
68+
69+
# Test with no model_name (should be standalone)
70+
standalone_test = TestConfig(
71+
name="standalone_test",
72+
sql="SELECT 1",
73+
model_name=None,
74+
dependencies=Dependencies(refs={"some_model"}),
7675
)
77-
upstream_test = TestConfig(
78-
name="upstream_with_downstream",
79-
sql="",
80-
dependencies=Dependencies(refs={"upstream", "downstream"}),
76+
assert standalone_test.is_standalone is True
77+
78+
# Test with only self-reference (should not be standalone)
79+
self_ref_test = TestConfig(
80+
name="self_ref_test",
81+
sql="SELECT * FROM {{ this }}",
82+
model_name="my_model",
83+
dependencies=Dependencies(refs={"my_model"}),
8184
)
85+
assert self_ref_test.is_standalone is False
86+
87+
# Test with no references (should not be standalone)
88+
no_ref_test = TestConfig(
89+
name="no_ref_test",
90+
sql="SELECT 1",
91+
model_name="my_model",
92+
dependencies=Dependencies(),
93+
)
94+
assert no_ref_test.is_standalone is False
95+
96+
# Test with references to other models (should be standalone)
97+
cross_ref_test = TestConfig(
98+
name="cross_ref_test",
99+
sql="SELECT * FROM {{ ref('other_model') }}",
100+
model_name="my_model",
101+
dependencies=Dependencies(refs={"my_model", "other_model"}),
102+
)
103+
assert cross_ref_test.is_standalone is True
104+
105+
# Test with only references to other models, no self-reference (should be standalone)
106+
other_only_test = TestConfig(
107+
name="other_only_test",
108+
sql="SELECT * FROM {{ ref('other_model') }}",
109+
model_name="my_model",
110+
dependencies=Dependencies(refs={"other_model"}),
111+
)
112+
assert other_only_test.is_standalone is True
113+
114+
115+
def test_test_to_sqlmesh_creates_correct_audit_type(
116+
dbt_dummy_postgres_config: PostgresConfig,
117+
) -> None:
118+
"""Test that TestConfig.to_sqlmesh creates the correct audit type based on is_standalone"""
119+
from sqlmesh.core.audit.definition import StandaloneAudit, ModelAudit
120+
121+
# Set up models in context
122+
my_model = ModelConfig(
123+
name="my_model", sql="SELECT 1", schema="test_schema", database="test_db", alias="my_model"
124+
)
125+
other_model = ModelConfig(
126+
name="other_model",
127+
sql="SELECT 2",
128+
schema="test_schema",
129+
database="test_db",
130+
alias="other_model",
131+
)
132+
context = DbtContext(
133+
_refs={"my_model": my_model, "other_model": other_model},
134+
_target=dbt_dummy_postgres_config,
135+
)
136+
137+
# Test with only self-reference (should create ModelAudit)
138+
self_ref_test = TestConfig(
139+
name="self_ref_test",
140+
sql="SELECT * FROM {{ this }}",
141+
model_name="my_model",
142+
dependencies=Dependencies(refs={"my_model"}),
143+
)
144+
audit = self_ref_test.to_sqlmesh(context)
145+
assert isinstance(audit, ModelAudit)
146+
assert audit.name == "self_ref_test"
147+
148+
# Test with references to other models (should create StandaloneAudit)
149+
cross_ref_test = TestConfig(
150+
name="cross_ref_test",
151+
sql="SELECT * FROM {{ ref('other_model') }}",
152+
model_name="my_model",
153+
dependencies=Dependencies(refs={"my_model", "other_model"}),
154+
)
155+
audit = cross_ref_test.to_sqlmesh(context)
156+
assert isinstance(audit, StandaloneAudit)
157+
assert audit.name == "cross_ref_test"
158+
159+
# Test with no model_name (should create StandaloneAudit)
160+
standalone_test = TestConfig(
161+
name="standalone_test",
162+
sql="SELECT 1",
163+
model_name=None,
164+
dependencies=Dependencies(),
165+
)
166+
audit = standalone_test.to_sqlmesh(context)
167+
assert isinstance(audit, StandaloneAudit)
168+
assert audit.name == "standalone_test"
169+
170+
171+
@pytest.mark.slow
172+
def test_manifest_filters_standalone_tests_from_models(
173+
tmp_path: Path, create_empty_project
174+
) -> None:
175+
"""Integration test that verifies models only contain non-standalone tests after manifest loading."""
176+
yaml = YAML()
177+
project_dir, model_dir = create_empty_project()
178+
179+
# Create two models
180+
model1_contents = "SELECT 1 as id"
181+
model1_file = model_dir / "model1.sql"
182+
with open(model1_file, "w", encoding="utf-8") as f:
183+
f.write(model1_contents)
184+
185+
model2_contents = "SELECT 2 as id"
186+
model2_file = model_dir / "model2.sql"
187+
with open(model2_file, "w", encoding="utf-8") as f:
188+
f.write(model2_contents)
189+
190+
# Create schema with both standalone and non-standalone tests
191+
schema_yaml = {
192+
"version": 2,
193+
"models": [
194+
{
195+
"name": "model1",
196+
"columns": [
197+
{
198+
"name": "id",
199+
"tests": [
200+
"not_null", # Non-standalone test - only references model1
201+
{
202+
"relationships": { # Standalone test - references model2
203+
"to": "ref('model2')",
204+
"field": "id",
205+
}
206+
},
207+
],
208+
}
209+
],
210+
},
211+
{
212+
"name": "model2",
213+
"columns": [
214+
{"name": "id", "tests": ["not_null"]} # Non-standalone test
215+
],
216+
},
217+
],
218+
}
219+
220+
schema_file = model_dir / "schema.yml"
221+
with open(schema_file, "w", encoding="utf-8") as f:
222+
yaml.dump(schema_yaml, f)
223+
224+
# Load the project through SQLMesh Context
225+
from sqlmesh.core.context import Context
226+
227+
context = Context(paths=project_dir)
82228

83-
# No circular reference
84-
downstream_model.tests = [downstream_test]
85-
downstream_model.fix_circular_test_refs(context)
86-
assert upstream_model.tests == []
87-
assert downstream_model.tests == [downstream_test]
88-
89-
# Upstream model reference in downstream model
90-
downstream_model.tests = []
91-
upstream_model.tests = [upstream_test]
92-
upstream_model.fix_circular_test_refs(context)
93-
assert upstream_model.tests == []
94-
assert downstream_model.tests == [upstream_test]
95-
96-
upstream_model.tests = [upstream_test]
97-
downstream_model.tests = [downstream_test]
98-
upstream_model.fix_circular_test_refs(context)
99-
assert upstream_model.tests == []
100-
assert downstream_model.tests == [downstream_test, upstream_test]
101-
102-
downstream_model.fix_circular_test_refs(context)
103-
assert upstream_model.tests == []
104-
assert downstream_model.tests == [downstream_test, upstream_test]
105-
106-
# Test only references
107-
upstream_model.tests = [upstream_test]
108-
downstream_model.tests = [downstream_test]
109-
downstream_model.dependencies = Dependencies()
110-
upstream_model.fix_circular_test_refs(context)
111-
assert upstream_model.tests == []
112-
assert downstream_model.tests == [downstream_test, upstream_test]
113-
114-
downstream_model.fix_circular_test_refs(context)
115-
assert upstream_model.tests == []
116-
assert downstream_model.tests == [downstream_test, upstream_test]
229+
model1_snapshot = context.snapshots['"local"."main"."model1"']
230+
model2_snapshot = context.snapshots['"local"."main"."model2"']
231+
232+
# Verify model1 only has non-standalone test in its audits
233+
# Should only have "not_null" test, not the "relationships" test
234+
# audits is a list of tuples (audit_name, audit_config)
235+
model1_audit_names = [audit[0] for audit in model1_snapshot.model.audits]
236+
assert len(model1_audit_names) == 1
237+
assert model1_audit_names[0] == "not_null_model1_id"
238+
239+
# Verify model2 has its non-standalone test
240+
model2_audit_names = [audit[0] for audit in model2_snapshot.model.audits]
241+
assert len(model2_audit_names) == 1
242+
assert model2_audit_names[0] == "not_null_model2_id"
243+
244+
# Verify the standalone test (relationships) exists as a StandaloneAudit
245+
all_non_standalone_audits = [name for name, audit in context._audits.items()]
246+
assert sorted(all_non_standalone_audits) == [
247+
"not_null_model1_id",
248+
"not_null_model2_id",
249+
]
250+
251+
standalone_audits = [name for name, audit in context._standalone_audits.items()]
252+
assert len(standalone_audits) == 1
253+
assert standalone_audits[0] == "relationships_model1_id__id__ref_model2_"
254+
255+
plan_builder = context.plan_builder()
256+
dag = plan_builder._build_dag()
257+
assert [x.name for x in dag.sorted] == [
258+
'"local"."main"."model1"',
259+
'"local"."main"."model2"',
260+
"relationships_model1_id__id__ref_model2_",
261+
]
117262

118263

119264
@pytest.mark.slow

0 commit comments

Comments
 (0)