diff --git a/src/impulse/application/use_cases.py b/src/impulse/application/use_cases.py index bf79454..1134220 100644 --- a/src/impulse/application/use_cases.py +++ b/src/impulse/application/use_cases.py @@ -1,5 +1,6 @@ from collections.abc import Set from collections.abc import Callable +import fnmatch import itertools import grimp from impulse import ports, dotfile @@ -14,6 +15,19 @@ def draw_graph( get_top_level_package: Callable[[str], str], build_graph: Callable[[str], grimp.ImportGraph], viewer: ports.GraphViewer, + depth: int = 1, +<<<<<<< HEAD +<<<<<<< HEAD + hide_unlinked: bool = False, +<<<<<<< HEAD +======= +>>>>>>> 6e23bb0 (depth) +======= + hide_unlinked: bool = False, +>>>>>>> 39624fc (hide unlinked) +======= + hide_nodes_patterns: list[str] | None = None, +>>>>>>> cbfc506 (hide nodes) ) -> None: """ Create a file showing a graph of the supplied package. @@ -28,6 +42,19 @@ def draw_graph( build_graph: the function which builds the graph of the supplied package (pass grimp.build_graph or a test double). viewer: GraphViewer for generating the graph image and opening it. + depth: the depth of submodules to include in the graph (default: 1 for direct children). +<<<<<<< HEAD +<<<<<<< HEAD + hide_unlinked: whether to hide nodes that have no incoming or outgoing edges. +<<<<<<< HEAD +======= +>>>>>>> 6e23bb0 (depth) +======= + hide_unlinked: whether to hide nodes that have no incoming or outgoing edges. +>>>>>>> 39624fc (hide unlinked) +======= + hide_nodes_patterns: list of fnmatch patterns to hide matching nodes. +>>>>>>> cbfc506 (hide nodes) """ # Add current directory to the path, as this doesn't happen automatically. sys_path.insert(0, current_directory) @@ -35,30 +62,153 @@ def draw_graph( top_level_package = get_top_level_package(module_name) grimp_graph = build_graph(top_level_package) - dot = _build_dot(grimp_graph, module_name, show_import_totals, show_cycle_breakers) +<<<<<<< HEAD +<<<<<<< HEAD + dot = _build_dot( + grimp_graph, module_name, show_import_totals, show_cycle_breakers, depth, hide_unlinked, + hide_nodes_patterns=hide_nodes_patterns or [], + ) +======= + dot = _build_dot(grimp_graph, module_name, show_import_totals, show_cycle_breakers, depth) +>>>>>>> 6e23bb0 (depth) +======= + dot = _build_dot( + grimp_graph, module_name, show_import_totals, show_cycle_breakers, depth, hide_unlinked + ) +>>>>>>> 39624fc (hide unlinked) viewer.view(dot) +def _find_modules_up_to_depth( + grimp_graph: grimp.ImportGraph, module_name: str, depth: int +) -> Set[str]: + """ + Find all modules up to and including the specified depth below the given module. + + For depth=1, returns direct children. + For depth=2, returns direct children AND grandchildren. + And so on. + """ + if depth < 1: + raise ValueError("Depth must be at least 1") + + all_modules: set[str] = set() + current_level = {module_name} + + for _ in range(depth): + next_level: set[str] = set() + for mod in current_level: + next_level.update(grimp_graph.find_children(mod)) + all_modules.update(next_level) + current_level = next_level + + return all_modules + + +def _should_hide_node(module: str, base_module: str, patterns: list[str]) -> bool: + """ + Check if a module should be hidden based on fnmatch patterns. + + Patterns are matched against the relative module name (without leading dot). + For example, if base_module is "mypackage.foo" and module is "mypackage.foo.bar.baz", + the relative name is "bar.baz" and patterns like "bar" or "bar.*" are matched against it. + """ + if not patterns: + return False + + # Get the relative module name (strip the base module prefix) + if module.startswith(base_module + "."): + relative = module[len(base_module) + 1:] # Strip base module and the dot + else: + # Fallback: use just the last component + relative = module.split(".")[-1] + + # Check if any pattern matches + for pattern in patterns: + if fnmatch.fnmatch(relative, pattern): + return True + + return False + + class _DotGraphBuildStrategy: +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD + def __init__(self, depth: int = 1, hide_unlinked: bool = False) -> None: + self.depth = depth + self.hide_unlinked = hide_unlinked +======= + def __init__(self, depth: int = 1) -> None: + self.depth = depth +>>>>>>> 6e23bb0 (depth) +======= + def __init__(self, depth: int = 1, hide_unlinked: bool = False) -> None: + self.depth = depth + self.hide_unlinked = hide_unlinked +>>>>>>> 39624fc (hide unlinked) +======= + def __init__(self, depth: int = 1, hide_unlinked: bool = False, hide_nodes_patterns: list[str] | None = None) -> None: + self.depth = depth + self.hide_unlinked = hide_unlinked + self.hide_nodes_patterns = hide_nodes_patterns or [] +>>>>>>> cbfc506 (hide nodes) + def build(self, module_name: str, grimp_graph: grimp.ImportGraph) -> dotfile.DotGraph: - children = grimp_graph.find_children(module_name) + modules = _find_modules_up_to_depth(grimp_graph, module_name, self.depth) + + # Filter out hidden nodes based on patterns + if self.hide_nodes_patterns: + modules = { + mod for mod in modules + if not _should_hide_node(mod, module_name, self.hide_nodes_patterns) + } - self.prepare_graph(grimp_graph, children) + self.prepare_graph(grimp_graph, modules) - dot = dotfile.DotGraph(title=module_name, concentrate=self.should_concentrate()) - for child in children: - dot.add_node(child) - for upstream, downstream in itertools.permutations(children, r=2): + dot = dotfile.DotGraph( + title=module_name, concentrate=self.should_concentrate(), depth=self.depth + ) +<<<<<<< HEAD +<<<<<<< HEAD + + # Build edges first so we can determine which nodes are linked + edges: list[dotfile.Edge] = [] +======= + for mod in modules: + dot.add_node(mod) +>>>>>>> 6e23bb0 (depth) +======= + + # Build edges first so we can determine which nodes are linked + edges: list[dotfile.Edge] = [] +>>>>>>> 39624fc (hide unlinked) + for upstream, downstream in itertools.permutations(modules, r=2): if edge := self.build_edge(grimp_graph, upstream, downstream): - dot.add_edge(edge) + edges.append(edge) + + # Determine which nodes have at least one connection + if self.hide_unlinked: + linked_nodes: set[str] = set() + for edge in edges: + linked_nodes.add(edge.source) + linked_nodes.add(edge.destination) + nodes_to_add = modules & linked_nodes + else: + nodes_to_add = modules + + for mod in nodes_to_add: + dot.add_node(mod) + for edge in edges: + dot.add_edge(edge) return dot def should_concentrate(self) -> bool: return True - def prepare_graph(self, grimp_graph: grimp.ImportGraph, children: Set[str]) -> None: + def prepare_graph(self, grimp_graph: grimp.ImportGraph, modules: Set[str]) -> None: pass def build_edge( @@ -70,9 +220,26 @@ def build_edge( class _ModuleSquashingBuildStrategy(_DotGraphBuildStrategy): """Fast builder for when we don't need additional data about the imports.""" - def prepare_graph(self, grimp_graph: grimp.ImportGraph, children: Set[str]) -> None: - for child in children: - grimp_graph.squash_module(child) +<<<<<<< HEAD +<<<<<<< HEAD +<<<<<<< HEAD + def __init__(self, depth: int = 1, hide_unlinked: bool = False) -> None: + super().__init__(depth=depth, hide_unlinked=hide_unlinked) +======= + def __init__(self, depth: int = 1, hide_unlinked: bool = False, hide_nodes_patterns: list[str] | None = None) -> None: + super().__init__(depth=depth, hide_unlinked=hide_unlinked, hide_nodes_patterns=hide_nodes_patterns) +>>>>>>> cbfc506 (hide nodes) + +======= +>>>>>>> 6e23bb0 (depth) +======= + def __init__(self, depth: int = 1, hide_unlinked: bool = False) -> None: + super().__init__(depth=depth, hide_unlinked=hide_unlinked) + +>>>>>>> 39624fc (hide unlinked) + def prepare_graph(self, grimp_graph: grimp.ImportGraph, modules: Set[str]) -> None: + for mod in modules: + grimp_graph.squash_module(mod) def build_edge( self, grimp_graph: grimp.ImportGraph, upstream: str, downstream: str @@ -84,12 +251,35 @@ def build_edge( class _ImportExpressionBuildStrategy(_DotGraphBuildStrategy): """Slower builder for when we want to work on the whole graph, - without squashing children. + without squashing modules. """ def __init__( - self, *, module_name: str, show_import_totals: bool, show_cycle_breakers: bool + self, + *, + module_name: str, + show_import_totals: bool, + show_cycle_breakers: bool, + depth: int = 1, +<<<<<<< HEAD +<<<<<<< HEAD + hide_unlinked: bool = False, + hide_nodes_patterns: list[str] | None = None, + ) -> None: +<<<<<<< HEAD + super().__init__(depth=depth, hide_unlinked=hide_unlinked) +======= + ) -> None: + super().__init__(depth=depth) +>>>>>>> 6e23bb0 (depth) +======= + hide_unlinked: bool = False, ) -> None: + super().__init__(depth=depth, hide_unlinked=hide_unlinked) +>>>>>>> 39624fc (hide unlinked) +======= + super().__init__(depth=depth, hide_unlinked=hide_unlinked, hide_nodes_patterns=hide_nodes_patterns) +>>>>>>> cbfc506 (hide nodes) self.module_name = module_name self.show_import_totals = show_import_totals self.show_cycle_breakers = show_cycle_breakers @@ -99,22 +289,22 @@ def should_concentrate(self) -> bool: # We need to see edge direction emphasized separately. return not (self.show_import_totals or self.show_cycle_breakers) - def prepare_graph(self, grimp_graph: grimp.ImportGraph, children: Set[str]) -> None: - super().prepare_graph(grimp_graph, children) + def prepare_graph(self, grimp_graph: grimp.ImportGraph, modules: Set[str]) -> None: + super().prepare_graph(grimp_graph, modules) if self.show_cycle_breakers: - self.cycle_breakers = self._get_coarse_grained_cycle_breakers(grimp_graph, children) + self.cycle_breakers = self._get_coarse_grained_cycle_breakers(grimp_graph, modules) def _get_coarse_grained_cycle_breakers( - self, grimp_graph: grimp.ImportGraph, children: Set[str] + self, grimp_graph: grimp.ImportGraph, modules: Set[str] ) -> set[tuple[str, str]]: # In the form (importer, imported). coarse_grained_cycle_breakers: set[tuple[str, str]] = set() for fine_grained_cycle_breaker in grimp_graph.nominate_cycle_breakers(self.module_name): importer, imported = fine_grained_cycle_breaker - importer_ancestor = self._get_self_or_ancestor(candidate=importer, ancestors=children) - imported_ancestor = self._get_self_or_ancestor(candidate=imported, ancestors=children) + importer_ancestor = self._get_self_or_ancestor(candidate=importer, ancestors=modules) + imported_ancestor = self._get_self_or_ancestor(candidate=imported, ancestors=modules) if importer_ancestor and imported_ancestor: coarse_grained_cycle_breakers.add((importer_ancestor, imported_ancestor)) @@ -131,9 +321,19 @@ def _get_self_or_ancestor(candidate: str, ancestors: Set[str]) -> str | None: def build_edge( self, grimp_graph: grimp.ImportGraph, upstream: str, downstream: str ) -> dotfile.Edge | None: - if grimp_graph.direct_import_exists( - importer=downstream, imported=upstream, as_packages=True - ): + # For depth > 1, we can't use as_packages=True because modules may share + # descendants (e.g., foo.blue and foo.blue.alpha are both in our set). + # In that case, only check for direct imports between exact modules. + if self.depth > 1: + import_exists = grimp_graph.direct_import_exists( + importer=downstream, imported=upstream, as_packages=False + ) + else: + import_exists = grimp_graph.direct_import_exists( + importer=downstream, imported=upstream, as_packages=True + ) + + if import_exists: if self.show_import_totals: number_of_imports = self._count_imports_between_packages( grimp_graph, importer=downstream, imported=upstream @@ -183,15 +383,51 @@ def _build_dot( module_name: str, show_import_totals: bool, show_cycle_breakers: bool, + depth: int = 1, +<<<<<<< HEAD +<<<<<<< HEAD + hide_unlinked: bool = False, +<<<<<<< HEAD +======= +>>>>>>> 6e23bb0 (depth) +======= + hide_unlinked: bool = False, +>>>>>>> 39624fc (hide unlinked) +======= + hide_nodes_patterns: list[str] | None = None, +>>>>>>> cbfc506 (hide nodes) ) -> dotfile.DotGraph: strategy: _DotGraphBuildStrategy - if show_import_totals or show_cycle_breakers: + # Use ImportExpressionBuildStrategy when: + # - show_import_totals or show_cycle_breakers is enabled, OR + # - depth > 1 (squashing would remove deeper modules we want to show) + if show_import_totals or show_cycle_breakers or depth > 1: strategy = _ImportExpressionBuildStrategy( module_name=module_name, show_import_totals=show_import_totals, show_cycle_breakers=show_cycle_breakers, + depth=depth, +<<<<<<< HEAD +<<<<<<< HEAD + hide_unlinked=hide_unlinked, + hide_nodes_patterns=hide_nodes_patterns, + ) + else: +<<<<<<< HEAD + strategy = _ModuleSquashingBuildStrategy(depth=depth, hide_unlinked=hide_unlinked) +======= + ) + else: + strategy = _ModuleSquashingBuildStrategy(depth=depth) +>>>>>>> 6e23bb0 (depth) +======= + hide_unlinked=hide_unlinked, ) else: - strategy = _ModuleSquashingBuildStrategy() + strategy = _ModuleSquashingBuildStrategy(depth=depth, hide_unlinked=hide_unlinked) +>>>>>>> 39624fc (hide unlinked) +======= + strategy = _ModuleSquashingBuildStrategy(depth=depth, hide_unlinked=hide_unlinked, hide_nodes_patterns=hide_nodes_patterns) +>>>>>>> cbfc506 (hide nodes) return strategy.build(module_name, grimp_graph) diff --git a/src/impulse/cli.py b/src/impulse/cli.py index 347a8fe..77e423e 100644 --- a/src/impulse/cli.py +++ b/src/impulse/cli.py @@ -35,6 +35,27 @@ def main(): help="Output format (default to html).", ) @click.option("--force-console", is_flag=True, help="Force the use of the console output.") +@click.option( + "--hide-unlinked", + is_flag=True, + help="Hide nodes that have no incoming or outgoing edges.", +) +@click.option( + "--hide-nodes", + type=str, + default="", + help=( + "Comma-separated list of fnmatch patterns to hide nodes. " + "Patterns are matched against relative module names (without leading dot). " + "Example: --hide-nodes=foo,bar.* hides .foo, .bar.plop, .bar.plip.plup" + ), +) +@click.option( + "--depth", + type=int, + default=1, + help="Depth of submodules to include in the graph (default: 1 for direct children).", +) @click.argument("module_name", type=str) def drawgraph( module_name: str, @@ -42,7 +63,15 @@ def drawgraph( show_cycle_breakers: bool, force_console: bool, format: str, + hide_unlinked: bool, + hide_nodes: str, + depth: int, ) -> None: + # Parse hide_nodes patterns (comma-separated list) + hide_nodes_patterns = ( + [p.strip() for p in hide_nodes.split(",") if p.strip()] if hide_nodes else [] + ) + viewer: ports.GraphViewer if format == "html": if not force_console and sys.stdout.isatty(): @@ -58,6 +87,9 @@ def drawgraph( module_name=module_name, show_import_totals=show_import_totals, show_cycle_breakers=show_cycle_breakers, + hide_unlinked=hide_unlinked, + hide_nodes_patterns=hide_nodes_patterns, + depth=depth, sys_path=sys.path, current_directory=os.getcwd(), get_top_level_package=adapters.get_top_level_package, diff --git a/src/impulse/dotfile.py b/src/impulse/dotfile.py index fa40d2f..5b75f9a 100644 --- a/src/impulse/dotfile.py +++ b/src/impulse/dotfile.py @@ -10,7 +10,10 @@ class Edge: emphasized: bool = False def __str__(self) -> str: - return f'"{DotGraph.render_module(self.source)}" -> "{DotGraph.render_module(self.destination)}"{self._render_attrs()}\n' + return self.render(base_module="") + + def render(self, base_module: str) -> str: + return f'"{DotGraph.render_module(self.source, base_module)}" -> "{DotGraph.render_module(self.destination, base_module)}"{self._render_attrs()}\n' def _render_attrs(self) -> str: attrs: dict[str, str] = {} @@ -32,11 +35,12 @@ class DotGraph: https://en.wikipedia.org/wiki/DOT_(graph_description_language) """ - def __init__(self, title: str, concentrate: bool = True) -> None: + def __init__(self, title: str, concentrate: bool = True, depth: int = 1) -> None: self.title = title self.nodes: set[str] = set() self.edges: set[Edge] = set() self.concentrate = concentrate + self.depth = depth def add_node(self, name: str) -> None: self.nodes.add(name) @@ -54,12 +58,19 @@ def render(self) -> str: }}""") def _render_nodes(self) -> str: - return "\n".join(f'"{self.render_module(node)}"\n' for node in sorted(self.nodes)) + return "\n".join( + f'"{self.render_module(node, self.title)}"\n' for node in sorted(self.nodes) + ) def _render_edges(self) -> str: - return "\n".join(str(edge) for edge in sorted(self.edges)) + return "\n".join(edge.render(self.title) for edge in sorted(self.edges)) @staticmethod - def render_module(module: str) -> str: - # Render as relative module. - return f".{module.split('.')[-1]}" + def render_module(module: str, base_module: str = "") -> str: + # Render as relative module by stripping the base module prefix. + if base_module and module.startswith(base_module + "."): + relative = module[len(base_module) :] + return relative # Already starts with "." + else: + # Fallback: show as relative with just the last component + return f".{module.split('.')[-1]}" diff --git a/tests/unit/application/test_use_cases.py b/tests/unit/application/test_use_cases.py index 95badd0..282f32d 100644 --- a/tests/unit/application/test_use_cases.py +++ b/tests/unit/application/test_use_cases.py @@ -179,3 +179,290 @@ def test_draw_graph_show_cycle_breakers(self): ), Edge("mypackage.foo.red", "mypackage.foo.blue", emphasized=True), } + + def test_draw_graph_depth_2(self): + """Test that depth=2 shows children AND grandchildren of the module.""" + + def build_graph_with_depth(package_name: str) -> grimp.ImportGraph: + graph = grimp.ImportGraph() + graph.add_module(package_name) + graph.add_module(SOME_MODULE) + + # Create a hierarchy: foo.blue, foo.green, foo.blue.alpha, foo.blue.beta, foo.green.gamma + for child in ("blue", "green"): + graph.add_module(f"{SOME_MODULE}.{child}") + for grandchild in ("alpha", "beta"): + graph.add_module(f"{SOME_MODULE}.blue.{grandchild}") + graph.add_module(f"{SOME_MODULE}.green.gamma") + + # Add imports at the grandchild level + graph.add_import( + importer=f"{SOME_MODULE}.blue.alpha", + imported=f"{SOME_MODULE}.green.gamma", + ) + graph.add_import( + importer=f"{SOME_MODULE}.blue.beta", + imported=f"{SOME_MODULE}.green.gamma", + ) + # Add import at the child level + graph.add_import( + importer=f"{SOME_MODULE}.blue", + imported=f"{SOME_MODULE}.green", + ) + return graph + + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_graph_with_depth, + viewer=viewer, + depth=2, + ) + + assert viewer.called_with_dot.depth == 2 + # depth=2 includes both children (depth 1) AND grandchildren (depth 2) + assert viewer.called_with_dot.nodes == { + "mypackage.foo.blue", + "mypackage.foo.green", + "mypackage.foo.blue.alpha", + "mypackage.foo.blue.beta", + "mypackage.foo.green.gamma", + } + assert viewer.called_with_dot.edges == { + Edge("mypackage.foo.blue", "mypackage.foo.green"), + Edge("mypackage.foo.blue.alpha", "mypackage.foo.green.gamma"), + Edge("mypackage.foo.blue.beta", "mypackage.foo.green.gamma"), + } +<<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> 39624fc (hide unlinked) + + def test_draw_graph_hide_unlinked(self): + """Test that hide_unlinked=True removes nodes with no edges.""" + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_fake_graph, + viewer=viewer, + hide_unlinked=True, + ) + + # The default graph has blue, green, yellow, red nodes + # yellow has no direct edges (only green -> yellow), but it IS connected + # All four nodes are connected in the test graph, so all should remain + assert viewer.called_with_dot.nodes == { + "mypackage.foo.green", + "mypackage.foo.blue", + "mypackage.foo.yellow", + "mypackage.foo.red", + } + + def test_draw_graph_hide_unlinked_removes_isolated_nodes(self): + """Test that hide_unlinked=True removes truly isolated nodes.""" + + def build_graph_with_isolated(package_name: str) -> grimp.ImportGraph: + graph = grimp.ImportGraph() + graph.add_module(package_name) + graph.add_module(SOME_MODULE) + + # Create some children, one of which is isolated + for child in ("blue", "green", "isolated"): + graph.add_module(f"{SOME_MODULE}.{child}") + + # Only blue and green are connected + graph.add_import( + importer=f"{SOME_MODULE}.blue", + imported=f"{SOME_MODULE}.green", + ) + # "isolated" has no imports + return graph + + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_graph_with_isolated, + viewer=viewer, + hide_unlinked=True, + ) + + # Only blue and green should remain; isolated should be filtered out + assert viewer.called_with_dot.nodes == { + "mypackage.foo.blue", + "mypackage.foo.green", + } + assert viewer.called_with_dot.edges == { + Edge("mypackage.foo.blue", "mypackage.foo.green"), + } +<<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> 6e23bb0 (depth) +======= +>>>>>>> 39624fc (hide unlinked) +======= + + def test_draw_graph_hide_nodes_exact_match(self): + """Test that hide_nodes_patterns hides exact matching nodes.""" + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_fake_graph, + viewer=viewer, + hide_nodes_patterns=["blue"], + ) + + # "blue" should be hidden, but "blue.alpha" etc. should NOT be hidden + # since the pattern "blue" only matches exactly "blue" + assert "mypackage.foo.blue" not in viewer.called_with_dot.nodes + assert "mypackage.foo.green" in viewer.called_with_dot.nodes + assert "mypackage.foo.yellow" in viewer.called_with_dot.nodes + assert "mypackage.foo.red" in viewer.called_with_dot.nodes + + def test_draw_graph_hide_nodes_wildcard_pattern(self): + """Test that hide_nodes_patterns with wildcard hides matching nodes.""" + + def build_graph_for_wildcard(package_name: str) -> grimp.ImportGraph: + graph = grimp.ImportGraph() + graph.add_module(package_name) + graph.add_module(SOME_MODULE) + + # Create a hierarchy to test wildcard patterns + for child in ("foo", "bar", "baz", "other"): + graph.add_module(f"{SOME_MODULE}.{child}") + + graph.add_import( + importer=f"{SOME_MODULE}.foo", + imported=f"{SOME_MODULE}.other", + ) + return graph + + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_graph_for_wildcard, + viewer=viewer, + hide_nodes_patterns=["ba*"], # Should hide "bar" and "baz" + ) + + assert viewer.called_with_dot.nodes == { + "mypackage.foo.foo", + "mypackage.foo.other", + } + + def test_draw_graph_hide_nodes_nested_wildcard(self): + """Test that hide_nodes_patterns with nested wildcard hides matching nodes.""" + + def build_graph_for_nested(package_name: str) -> grimp.ImportGraph: + graph = grimp.ImportGraph() + graph.add_module(package_name) + graph.add_module(SOME_MODULE) + + # Create depth 2 hierarchy + for child in ("bar", "other"): + graph.add_module(f"{SOME_MODULE}.{child}") + for grandchild in ("plop", "plip"): + graph.add_module(f"{SOME_MODULE}.bar.{grandchild}") + graph.add_module(f"{SOME_MODULE}.other.thing") + + graph.add_import( + importer=f"{SOME_MODULE}.bar.plop", + imported=f"{SOME_MODULE}.other.thing", + ) + return graph + + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_graph_for_nested, + viewer=viewer, + depth=2, + hide_nodes_patterns=["bar.*"], # Should hide "bar.plop" and "bar.plip" + ) + + # bar.plop and bar.plip should be hidden + assert "mypackage.foo.bar.plop" not in viewer.called_with_dot.nodes + assert "mypackage.foo.bar.plip" not in viewer.called_with_dot.nodes + # bar and other.thing should remain + assert "mypackage.foo.bar" in viewer.called_with_dot.nodes + assert "mypackage.foo.other" in viewer.called_with_dot.nodes + assert "mypackage.foo.other.thing" in viewer.called_with_dot.nodes + + def test_draw_graph_hide_nodes_multiple_patterns(self): + """Test that multiple hide_nodes_patterns work together.""" + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_fake_graph, + viewer=viewer, + hide_nodes_patterns=["blue", "red"], # Hide both blue and red + ) + + assert viewer.called_with_dot.nodes == { + "mypackage.foo.green", + "mypackage.foo.yellow", + } + + def test_draw_graph_hide_nodes_removes_corresponding_edges(self): + """Test that hiding nodes also removes edges to/from those nodes.""" + viewer = SpyGraphViewer() + + use_cases.draw_graph( + SOME_MODULE, + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, + build_graph=build_fake_graph, + viewer=viewer, + hide_nodes_patterns=["blue"], + ) + + # Edges involving blue should be gone + for edge in viewer.called_with_dot.edges: + assert "blue" not in edge.source + assert "blue" not in edge.destination +>>>>>>> cbfc506 (hide nodes)