Skip to content

Commit a74d964

Browse files
committed
fix: dbt handle indirect cycles
1 parent a6945cb commit a74d964

File tree

4 files changed

+315
-27
lines changed

4 files changed

+315
-27
lines changed

sqlmesh/dbt/basemodel.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sqlmesh.dbt.test import TestConfig
3131
from sqlmesh.dbt.util import DBT_VERSION
3232
from sqlmesh.utils import AttributeDict
33+
from sqlmesh.utils.dag import find_path_with_dfs
3334
from sqlmesh.utils.errors import ConfigError
3435
from sqlmesh.utils.pydantic import field_validator
3536

@@ -270,9 +271,10 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
270271

271272
def fix_circular_test_refs(self, context: DbtContext) -> None:
272273
"""
273-
Checks for direct circular references between two models and moves the test to the downstream
274-
model if found. This addresses the most common circular reference - relationship tests in both
275-
directions. In the future, we may want to increase coverage by checking for indirect circular references.
274+
Checks for circular references between models and moves tests to break cycles.
275+
This handles both direct circular references (A -> B -> A) and indirect circular
276+
references (A -> B -> C -> A). Tests are moved to the model that appears latest
277+
in the dependency chain to ensure the cycle is broken.
276278
277279
Args:
278280
context: The dbt context this model resides within.
@@ -284,16 +286,91 @@ def fix_circular_test_refs(self, context: DbtContext) -> None:
284286
for ref in test.dependencies.refs:
285287
if ref == self.name or ref in self.dependencies.refs:
286288
continue
287-
model = context.refs[ref]
288-
if (
289-
self.name in model.dependencies.refs
290-
or self.name in model.tests_ref_source_dependencies.refs
291-
):
289+
290+
# Check if moving this test would create or maintain a cycle
291+
cycle_path = self._find_circular_path(ref, context, set())
292+
if cycle_path:
293+
# Find the model in the cycle that should receive the test
294+
# We want to move to the model that appears latest in the dependency chain
295+
target_model_name = self._select_target_model_for_test(cycle_path, context)
296+
target_model = context.refs[target_model_name]
297+
292298
logger.info(
293-
f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference."
299+
f"Moving test '{test.name}' from model '{self.name}' to '{target_model_name}' "
300+
f"to avoid circular reference through path: {' -> '.join(cycle_path)}"
294301
)
295-
model.tests.append(test)
302+
target_model.tests.append(test)
296303
self.tests.remove(test)
304+
break
305+
306+
def _find_circular_path(
    self, ref: str, context: DbtContext, visited: t.Set[str]
) -> t.Optional[t.List[str]]:
    """
    Find a dependency path that leads from ``ref`` back to this model, if one exists.

    A graph is collected from every model reachable from ``ref``, where each model's
    edges are the union of its own refs and the refs of its tests. The graph is then
    searched for a path from ``ref`` to ``self.name``.

    Args:
        ref: The model name to start searching from.
        context: The dbt context; ``context.refs`` is used to look up models by name.
        visited: Model names that must not be expanded while collecting the graph.
            This set is only read, never modified.

    Returns:
        List of model names forming the path from ``ref`` to this model, or None if
        this model cannot be reached from ``ref``.

    Raises:
        KeyError: If a reachable model name is missing from ``context.refs``.
    """
    graph: t.Dict[str, t.Set[str]] = {}

    # An explicit worklist is used instead of recursion so that a very long
    # dependency chain cannot exhaust the interpreter's recursion limit.
    pending: t.List[str] = [ref]
    while pending:
        node_name = pending.pop()
        if node_name in visited or node_name in graph:
            continue

        model = context.refs[node_name]
        # Edges include both direct model dependencies and test dependencies.
        # The union already produces a new set, so no extra copy is required.
        all_refs = model.dependencies.refs | model.tests_ref_source_dependencies.refs
        graph[node_name] = all_refs
        pending.extend(all_refs)

    # The search only follows edges to nodes that are keys of the graph, so this
    # model must be present as a key for it to be reachable as the target.
    graph.setdefault(self.name, set())

    return find_path_with_dfs(graph, start_node=ref, target_node=self.name)
346+
347+
def _select_target_model_for_test(self, cycle_path: t.List[str], context: DbtContext) -> str:
    """
    Select which model in the cycle should receive the test.

    For every model in ``cycle_path`` the number of its refs (model refs plus the
    refs of its tests) that are themselves part of the cycle is counted. The model
    with the fewest such refs is chosen; when several models tie, the one that
    appears first in ``cycle_path`` wins.

    Args:
        cycle_path: List of model names in the circular dependency path.
        context: The dbt context; ``context.refs`` is used to look up models by name.

    Returns:
        Name of the model that should receive the test.

    Raises:
        ValueError: If ``cycle_path`` is empty.
        KeyError: If a name in ``cycle_path`` is missing from ``context.refs``.
    """
    if not cycle_path:
        raise ValueError("cycle_path must contain at least one model name")

    # Set membership keeps the per-model count linear in the number of refs.
    cycle_members = set(cycle_path)

    def refs_within_cycle(model_name: str) -> int:
        model = context.refs[model_name]
        all_refs = model.dependencies.refs | model.tests_ref_source_dependencies.refs
        return len(all_refs & cycle_members)

    # min() returns the first minimal element, which gives the documented tie-break.
    return min(cycle_path, key=refs_within_cycle)
297374

298375
@property
299376
def sqlmesh_config_fields(self) -> t.Set[str]:

sqlmesh/utils/dag.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,82 @@
1515
T = t.TypeVar("T", bound=t.Hashable)
1616

1717

18+
def find_path_with_dfs(
    graph: t.Dict[T, t.Set[T]],
    start_node: t.Optional[T] = None,
    target_node: t.Optional[T] = None,
) -> t.Optional[t.List[T]]:
    """
    Find a path in a graph using an iterative depth-first search.

    This function can be used for two main purposes:
    1. Find any cycle in the graph (when target_node is None)
    2. Find a specific path from start_node to target_node

    Only edges whose destination is itself a key of ``graph`` are followed, so a
    target that is not a key of ``graph`` is never reached.

    The search keeps its own stack rather than recursing, so the length of the
    path that can be explored is not bounded by the interpreter's recursion limit.

    Args:
        graph: Dictionary mapping nodes to their dependencies/neighbors
        start_node: Optional specific node to start the search from. If None, every
            key of ``graph`` is tried in turn.
        target_node: Optional target node to search for. If None, finds any cycle

    Returns:
        In cycle mode, the nodes of the cycle in traversal order without repeating
        the first node at the end. In target mode, the nodes from the starting node
        to ``target_node`` inclusive. None if no path/cycle is found.
    """
    if not graph:
        return None

    find_cycle = target_node is None
    visited: t.Set[T] = set()

    # Determine which nodes to try as starting points
    candidates = [start_node] if start_node is not None else list(graph)

    for root in candidates:
        if root in visited or root not in graph:
            continue

        # Target search: the starting node may already be the target
        if not find_cycle and root == target_node:
            return [root]

        visited.add(root)
        path: t.List[T] = [root]
        on_path: t.Set[T] = {root}
        # One neighbor iterator per node on the current path; resuming an iterator
        # continues exactly where the exploration of that node left off.
        stack = [iter(graph[root])]

        while stack:
            for neighbor in stack[-1]:
                if neighbor not in graph:
                    # Only follow edges to nodes in our subgraph
                    continue
                if find_cycle:
                    if neighbor in on_path:
                        # Back edge: the cycle is the path suffix starting at neighbor
                        return path[path.index(neighbor) :]
                elif neighbor == target_node:
                    return path + [neighbor]
                if neighbor in visited:
                    continue

                visited.add(neighbor)
                on_path.add(neighbor)
                path.append(neighbor)
                stack.append(iter(graph[neighbor]))
                break
            else:
                # All neighbors of the deepest node are exhausted: backtrack
                stack.pop()
                on_path.discard(path.pop())

    return None
92+
93+
1894
class DAG(t.Generic[T]):
1995
def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
2096
self._dag: t.Dict[T, t.Set[T]] = {}
@@ -99,6 +175,17 @@ def upstream(self, node: T) -> t.Set[T]:
99175

100176
return self._upstream[node]
101177

178+
def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]:
    """Locate one concrete cycle among the given nodes.

    Args:
        nodes_in_cycle: Mapping of each candidate node to its dependencies.

    Returns:
        The nodes forming a cycle, or None when no cycle is found.
    """
    # With neither a start nor a target the shared DFS helper runs in
    # cycle-detection mode and tries every key as a starting point.
    cycle = find_path_with_dfs(nodes_in_cycle, start_node=None, target_node=None)
    return cycle
188+
102189
@property
103190
def roots(self) -> t.Set[T]:
104191
"""Returns all nodes in the graph without any upstream dependencies."""
@@ -125,23 +212,28 @@ def sorted(self) -> t.List[T]:
125212
next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}
126213

127214
if not next_nodes:
128-
# Sort cycle candidates to make the order deterministic
129-
cycle_candidates_msg = (
130-
"\nPossible candidates to check for circular references: "
131-
+ ", ".join(str(node) for node in sorted(cycle_candidates))
132-
)
215+
# A cycle was detected - find the exact cycle path
216+
cycle_path = self._find_cycle_path(unprocessed_nodes)
133217

134-
if last_processed_nodes:
135-
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
136-
str(node) for node in last_processed_nodes
137-
)
218+
last_processed_msg = ""
219+
if cycle_path:
220+
cycle_msg = f"\nCycle: {' -> '.join(str(node) for node in cycle_path)} -> {cycle_path[0]}"
138221
else:
139-
last_processed_msg = ""
222+
# Fallback message in case a cycle can't be found
223+
cycle_candidates_msg = (
224+
"\nPossible candidates to check for circular references: "
225+
+ ", ".join(str(node) for node in sorted(cycle_candidates))
226+
)
227+
cycle_msg = cycle_candidates_msg
228+
if last_processed_nodes:
229+
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
230+
str(node) for node in last_processed_nodes
231+
)
140232

141233
raise SQLMeshError(
142234
"Detected a cycle in the DAG. "
143235
"Please make sure there are no circular references between nodes."
144-
f"{last_processed_msg}{cycle_candidates_msg}"
236+
f"{last_processed_msg}{cycle_msg}"
145237
)
146238

147239
for node in next_nodes:

tests/dbt/test_model.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,126 @@ def test_model_test_circular_references() -> None:
115115
assert downstream_model.tests == [downstream_test, upstream_test]
116116

117117

118+
def test_model_test_indirect_circular_references() -> None:
    """A test on the root of a three-model chain that refs the leaf is moved to the leaf."""
    # Linear chain with no model-level cycle: model_b refs model_a, model_c refs model_b.
    model_a = ModelConfig(name="model_a")
    model_b = ModelConfig(name="model_b", dependencies=Dependencies(refs={"model_a"}))
    model_c = ModelConfig(name="model_c", dependencies=Dependencies(refs={"model_b"}))

    refs = {"model_a": model_a, "model_b": model_b, "model_c": model_c}
    context = DbtContext(_refs=refs)

    # Attached to model_a, this test closes the loop:
    # model_a (via test) -> model_c -> model_b -> model_a
    test_a_refs_c = TestConfig(
        name="test_a_refs_c",
        sql="",
        dependencies=Dependencies(refs={"model_a", "model_c"}),
    )
    model_a.tests = [test_a_refs_c]

    # The other two models start without tests.
    assert model_b.tests == []
    assert model_c.tests == []

    model_a.fix_circular_test_refs(context)

    # The test must have left model_a and landed on model_c.
    assert model_a.tests == []
    assert test_a_refs_c in model_c.tests
148+
149+
150+
def test_model_test_complex_indirect_circular_references() -> None:
    """Tests closing longer indirect cycles on a four-model chain are moved to the expected model."""
    # Linear chain with no model-level cycle: each model refs the previous one.
    model_a = ModelConfig(name="model_a")
    model_b = ModelConfig(name="model_b", dependencies=Dependencies(refs={"model_a"}))
    model_c = ModelConfig(name="model_c", dependencies=Dependencies(refs={"model_b"}))
    model_d = ModelConfig(name="model_d", dependencies=Dependencies(refs={"model_c"}))

    context = DbtContext(
        _refs={"model_a": model_a, "model_b": model_b, "model_c": model_c, "model_d": model_d}
    )

    def assign_tests(a_tests, b_tests, c_tests, d_tests) -> None:
        # Reset the tests on all four models in chain order.
        model_a.tests = a_tests
        model_b.tests = b_tests
        model_c.tests = c_tests
        model_d.tests = d_tests

    # Scenario 1: model_a (via test) -> model_d -> model_c -> model_b -> model_a
    test_a_refs_d = TestConfig(
        name="test_a_refs_d",
        sql="",
        dependencies=Dependencies(refs={"model_a", "model_d"}),
    )
    model_a.tests = [test_a_refs_d]
    model_b.tests = []
    # model_c and model_d are left untouched and must still be empty.
    assert model_c.tests == []
    assert model_d.tests == []

    model_a.fix_circular_test_refs(context)
    assert model_a.tests == []
    assert model_d.tests == [test_a_refs_d]

    # Scenario 2: model_b (via test) -> model_d -> model_c -> model_b
    test_b_refs_d = TestConfig(
        name="test_b_refs_d",
        sql="",
        dependencies=Dependencies(refs={"model_b", "model_d"}),
    )
    assign_tests([], [test_b_refs_d], [], [])

    model_b.fix_circular_test_refs(context)
    assert model_a.tests == []
    assert model_b.tests == []
    assert model_c.tests == []
    assert model_d.tests == [test_b_refs_d]

    # Scenario 3: both tests at once, fixed in chain order.
    assign_tests([test_a_refs_d], [test_b_refs_d], [], [])

    model_a.fix_circular_test_refs(context)
    model_b.fix_circular_test_refs(context)
    assert model_a.tests == []
    assert model_b.tests == []
    assert model_c.tests == []
    assert model_d.tests == [test_a_refs_d, test_b_refs_d]

    # Scenario 4: the cycle only reaches model_c, so the test must end up there.
    test_a_refs_c = TestConfig(
        name="test_a_refs_c",
        sql="",
        dependencies=Dependencies(refs={"model_a", "model_c"}),
    )
    assign_tests([test_a_refs_c], [], [], [])

    model_a.fix_circular_test_refs(context)
    assert model_a.tests == []
    assert model_b.tests == []
    assert model_c.tests == [test_a_refs_c]
    assert model_d.tests == []
236+
237+
118238
@pytest.mark.slow
119239
def test_load_invalid_ref_audit_constraints(
120240
tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project

0 commit comments

Comments
 (0)