def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]:
    """Find one concrete cycle among the given nodes using an iterative DFS.

    The search uses an explicit stack rather than recursion, so a long
    dependency chain in the unprocessed set cannot exceed the interpreter's
    recursion limit while an error message is being built.

    Args:
        nodes_in_cycle: Mapping of each candidate node to the set of nodes it
            depends on. Dependencies that are not themselves keys of this
            mapping are ignored.

    Returns:
        The nodes forming a cycle, each listed once in traversal order (the
        closing edge back to the first node is implied), or None if the
        mapping is empty or contains no cycle.
    """
    if not nodes_in_cycle:
        return None

    # Nodes that have been entered at least once, across all start nodes.
    visited: t.Set[T] = set()
    # Current DFS path, plus each path node's index for O(1) "is on path"
    # checks and O(1) lookup of where a detected cycle begins.
    path: t.List[T] = []
    path_index: t.Dict[T, int] = {}
    # One neighbor iterator per node on the path; pending[i] belongs to path[i].
    pending: t.List[t.Iterator[T]] = []

    for start_node in nodes_in_cycle:
        if start_node in visited:
            continue

        # path is always empty here: a finished exploration pops everything.
        visited.add(start_node)
        path_index[start_node] = len(path)
        path.append(start_node)
        pending.append(iter(nodes_in_cycle[start_node]))

        while pending:
            for neighbor in pending[-1]:
                # Only follow edges to nodes that are still candidates.
                if neighbor not in nodes_in_cycle:
                    continue

                if neighbor in path_index:
                    # Back edge: the cycle is the path suffix starting at neighbor.
                    return path[path_index[neighbor]:]

                if neighbor in visited:
                    # Fully explored earlier without finding a cycle through it.
                    continue

                visited.add(neighbor)
                path_index[neighbor] = len(path)
                path.append(neighbor)
                pending.append(iter(nodes_in_cycle[neighbor]))
                # Descend: resume this node's iterator after the child finishes.
                break
            else:
                # Neighbors exhausted: backtrack one level.
                pending.pop()
                del path_index[path.pop()]

    return None
self._find_cycle_path(unprocessed_nodes) - if last_processed_nodes: - last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join( - str(node) for node in last_processed_nodes + last_processed_msg = "" + if cycle_path: + node_output = " ->\n".join( + str(node) for node in (cycle_path + [cycle_path[0]]) ) + cycle_msg = f"\nCycle:\n{node_output}" else: - last_processed_msg = "" + # Fallback message in case a cycle can't be found + cycle_candidates_msg = ( + "\nPossible candidates to check for circular references: " + + ", ".join(str(node) for node in sorted(cycle_candidates)) + ) + cycle_msg = cycle_candidates_msg + if last_processed_nodes: + last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join( + str(node) for node in last_processed_nodes + ) raise SQLMeshError( "Detected a cycle in the DAG. " "Please make sure there are no circular references between nodes." - f"{last_processed_msg}{cycle_candidates_msg}" + f"{last_processed_msg}{cycle_msg}" ) for node in next_nodes: diff --git a/tests/utils/test_dag.py b/tests/utils/test_dag.py index 444e78555c..7c142ee4a0 100644 --- a/tests/utils/test_dag.py +++ b/tests/utils/test_dag.py @@ -57,8 +57,7 @@ def test_sorted_with_cycles(): expected_error_message = ( "Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n" - "Last nodes added to the DAG: c\n" - "Possible candidates to check for circular references: d, e" + "Cycle:\nd ->\ne ->\nd" ) assert expected_error_message == str(ex.value) @@ -70,7 +69,7 @@ def test_sorted_with_cycles(): expected_error_message = ( "Detected a cycle in the DAG. 
Please make sure there are no circular references between nodes.\n" - "Possible candidates to check for circular references: a, b, c" + "Cycle:\na ->\nb ->\nc ->\na" ) assert expected_error_message == str(ex.value) @@ -81,11 +80,11 @@ def test_sorted_with_cycles(): dag.sorted expected_error_message = ( - "Last nodes added to the DAG: c\n" - + "Possible candidates to check for circular references: b, d" + "Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n" + + "Cycle:\nb ->\nd ->\nb" ) - assert expected_error_message in str(ex.value) + assert expected_error_message == str(ex.value) def test_reversed_graph():