3030from sqlmesh .dbt .test import TestConfig
3131from sqlmesh .dbt .util import DBT_VERSION
3232from sqlmesh .utils import AttributeDict
33- from sqlmesh .utils .dag import find_path_with_dfs
3433from sqlmesh .utils .errors import ConfigError
3534from sqlmesh .utils .pydantic import field_validator
3635
@@ -271,10 +270,9 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
271270
272271 def fix_circular_test_refs (self , context : DbtContext ) -> None :
273272 """
274- Checks for circular references between models and moves tests to break cycles.
275- This handles both direct circular references (A -> B -> A) and indirect circular
276- references (A -> B -> C -> A). Tests are moved to the model that appears latest
277- in the dependency chain to ensure the cycle is broken.
273+ Checks for direct circular references between two models and moves the test to the downstream
274+ model if found. This addresses the most common circular reference - relationship tests in both
275+ directions. In the future, we may want to increase coverage by checking for indirect circular references.
278276
279277 Args:
280278 context: The dbt context this model resides within.
@@ -286,91 +284,16 @@ def fix_circular_test_refs(self, context: DbtContext) -> None:
286284 for ref in test .dependencies .refs :
287285 if ref == self .name or ref in self .dependencies .refs :
288286 continue
289-
290- # Check if moving this test would create or maintain a cycle
291- cycle_path = self ._find_circular_path (ref , context , set ())
292- if cycle_path :
293- # Find the model in the cycle that should receive the test
294- # We want to move to the model that appears latest in the dependency chain
295- target_model_name = self ._select_target_model_for_test (cycle_path , context )
296- target_model = context .refs [target_model_name ]
297-
287+ model = context .refs [ref ]
288+ if (
289+ self .name in model .dependencies .refs
290+ or self .name in model .tests_ref_source_dependencies .refs
291+ ):
298292 logger .info (
299- f"Moving test '{ test .name } ' from model '{ self .name } ' to '{ target_model_name } ' "
300- f"to avoid circular reference through path: { ' -> ' .join (cycle_path )} "
293+ f"Moving test '{ test .name } ' from model '{ self .name } ' to '{ model .name } ' to avoid circular reference."
301294 )
302- target_model .tests .append (test )
295+ model .tests .append (test )
303296 self .tests .remove (test )
304- break
305-
306- def _find_circular_path (
307- self , ref : str , context : DbtContext , visited : t .Set [str ]
308- ) -> t .Optional [t .List [str ]]:
309- """
310- Find if there's a circular dependency path from ref back to this model.
311-
312- Args:
313- ref: The model name to start searching from
314- context: The dbt context
315- visited: Set of model names already visited in this path
316-
317- Returns:
318- List of model names forming the circular path, or None if no cycle exists
319- """
320- # Build a graph of all models and their dependencies from the context
321- graph : t .Dict [str , t .Set [str ]] = {}
322-
323- def build_graph_from_node (node_name : str , current_visited : t .Set [str ]) -> None :
324- if node_name in current_visited or node_name in graph :
325- return
326- current_visited .add (node_name )
327-
328- model = context .refs [node_name ]
329- # Include both direct model dependencies and test dependencies
330- all_refs = model .dependencies .refs | model .tests_ref_source_dependencies .refs
331- graph [node_name ] = all_refs .copy ()
332-
333- # Recursively build graph for dependencies
334- for dep in all_refs :
335- build_graph_from_node (dep , current_visited )
336-
337- # Build the graph starting from the ref, including visited nodes to avoid infinite recursion
338- build_graph_from_node (ref , visited .copy ())
339-
340- # Add self.name to the graph if it's not already there
341- if self .name not in graph :
342- graph [self .name ] = set ()
343-
344- # Use the shared DFS function to find path from ref to self.name
345- return find_path_with_dfs (graph , start_node = ref , target_node = self .name )
346-
347- def _select_target_model_for_test (self , cycle_path : t .List [str ], context : DbtContext ) -> str :
348- """
349- Select which model in the cycle should receive the test.
350- We select the model that has the most downstream dependencies in the cycle
351-
352- Args:
353- cycle_path: List of model names in the circular dependency path
354- context: The dbt context
355-
356- Returns:
357- Name of the model that should receive the test
358- """
359- # Count how many other models in the cycle each model depends on
360- dependency_counts = {}
361-
362- for model_name in cycle_path :
363- model = context .refs [model_name ]
364- all_refs = model .dependencies .refs | model .tests_ref_source_dependencies .refs
365- count = len ([ref for ref in all_refs if ref in cycle_path ])
366- dependency_counts [model_name ] = count
367-
368- # Return the model with the fewest dependencies within the cycle
369- # (i.e., the most downstream model in the cycle)
370- if dependency_counts :
371- return min (dependency_counts , key = dependency_counts .get ) # type: ignore
372- # Fallback to the last model in the path
373- return cycle_path [- 1 ]
374297
375298 @property
376299 def sqlmesh_config_fields (self ) -> t .Set [str ]:
0 commit comments