From c6e80b7642cc6c308507aa6e9c21025532c9fb3a Mon Sep 17 00:00:00 2001 From: David Seddon Date: Fri, 31 Oct 2025 12:25:23 +0000 Subject: [PATCH] Support namespace packages --- CHANGELOG.rst | 1 + justfile | 1 + src/impulse/adapters.py | 19 ++++++++++++++ src/impulse/application/use_cases.py | 7 +++-- src/impulse/cli.py | 1 + tests/unit/application/test_use_cases.py | 33 ++++++++++++++++++++++++ 6 files changed, 60 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0c0d54d..09576b4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,7 @@ latest ------ * Add --show-cycle-breakers flag. +* Support namespace packages. 2.0 (2025-10-20) ---------------- diff --git a/justfile b/justfile index c41a162..1782128 100644 --- a/justfile +++ b/justfile @@ -11,6 +11,7 @@ install-precommit: test: @uv run pytest @uv run impulse drawgraph grimp + @uv run --with=google-cloud-audit-log impulse drawgraph google.cloud.audit @uv run impulse drawgraph grimp --show-import-totals @uv run --with=django impulse drawgraph django.db --show-cycle-breakers diff --git a/src/impulse/adapters.py b/src/impulse/adapters.py index 45eb041..c404048 100644 --- a/src/impulse/adapters.py +++ b/src/impulse/adapters.py @@ -3,6 +3,7 @@ from impulse import ports from textwrap import dedent from impulse import dotfile +import importlib class BrowserGraphViewer(ports.GraphViewer): @@ -169,3 +170,21 @@ def view(self, dot: dotfile.DotGraph) -> None: # Open in browser webbrowser.open(f"file://{html_path}") + + +def get_top_level_package(module_name: str) -> str: + """ + Returns the top-level package name from the given module name. + + This will usually be the first part of the dotted module name (before the first dot), but for namespace packages + it will be the 'portion' name. + """ + + # Successively work through the module components until we encounter one with a corresponding file. + components = module_name.split(".") + for level in range(len(components)): + candidate_name = ".".join(components[: level + 1]) + candidate = importlib.import_module(candidate_name) + if candidate.__file__: + return candidate_name + raise ImportError(f"Can't import module '{module_name}'. Is it on the Python path?") diff --git a/src/impulse/application/use_cases.py b/src/impulse/application/use_cases.py index f329ffd..071ff2f 100644 --- a/src/impulse/application/use_cases.py +++ b/src/impulse/application/use_cases.py @@ -12,6 +12,7 @@ def draw_graph( show_cycle_breakers: bool, sys_path: list[str], current_directory: str, + get_top_level_package: Callable[[str], str], build_graph: Callable[[str], grimp.ImportGraph], viewer: ports.GraphViewer, ) -> None: @@ -23,6 +24,8 @@ def draw_graph( show_cycle_breakers: marks a set of dependencies that, if removed, would make the graph acyclic. sys_path: the sys.path list (or a test double). current_directory: the current working directory. + get_top_level_package: the function to retrieve the top level package name. This will usually be the first part + of the dotted module name (before the first dot), but for namespace packages it should be the 'portion' name. build_graph: the function which builds the graph of the supplied package (pass grimp.build_graph or a test double). viewer: GraphViewer for generating the graph image and opening it. @@ -30,8 +33,8 @@ def draw_graph( # Add current directory to the path, as this doesn't happen automatically. sys_path.insert(0, current_directory) - module = grimp.Module(module_name) - grimp_graph = build_graph(module.package_name) + top_level_package = get_top_level_package(module_name) + grimp_graph = build_graph(top_level_package) dot = _build_dot(grimp_graph, module_name, show_import_totals, show_cycle_breakers) diff --git a/src/impulse/cli.py b/src/impulse/cli.py index 0bbf695..fee13f9 100644 --- a/src/impulse/cli.py +++ b/src/impulse/cli.py @@ -35,6 +35,7 @@ def drawgraph(module_name: str, show_import_totals: bool, show_cycle_breakers: b show_cycle_breakers=show_cycle_breakers, sys_path=sys.path, current_directory=os.getcwd(), + get_top_level_package=adapters.get_top_level_package, build_graph=grimp.build_graph, viewer=adapters.BrowserGraphViewer(), ) diff --git a/tests/unit/application/test_use_cases.py b/tests/unit/application/test_use_cases.py index 8fbac2d..a2123c1 100644 --- a/tests/unit/application/test_use_cases.py +++ b/tests/unit/application/test_use_cases.py @@ -10,6 +10,10 @@ SOME_MODULE = f"{SOME_ROOT_PACKAGE}.foo" +def fake_get_top_level_package_non_namespace(module_name: str) -> str: + return module_name.split(".")[0] + + def build_fake_graph(package_name: str) -> grimp.ImportGraph: graph = grimp.ImportGraph() graph.add_module(package_name) @@ -74,6 +78,7 @@ def test_draw_graph(self): show_cycle_breakers=False, sys_path=sys_path, current_directory=current_directory, + get_top_level_package=fake_get_top_level_package_non_namespace, build_graph=build_fake_graph, viewer=viewer, ) @@ -97,6 +102,32 @@ def test_draw_graph(self): Edge("mypackage.foo.red", "mypackage.foo.blue"), } + def test_draw_graph_calls_top_level_package(self): + def get_top_level_package(module: str) -> str: + return "some.namespace" + + def asserting_build_graph(top_level_package: str) -> grimp.ImportGraph: + assert top_level_package == "some.namespace" + graph = grimp.ImportGraph() + graph.add_module("some.namespace") + graph.add_module("some.namespace.foo") + graph.add_module("some.namespace.foo.blue") + graph.add_module("some.namespace.foo.blue.alpha") + graph.add_module("some.namespace.foo.blue.beta") + return graph + + viewer = SpyGraphViewer() + use_cases.draw_graph( + "some.namespace.foo.blue", + show_import_totals=False, + show_cycle_breakers=False, + sys_path=[], + current_directory="/cwd", + get_top_level_package=get_top_level_package, + build_graph=asserting_build_graph, + viewer=viewer, + ) + def test_draw_graph_show_import_totals(self): viewer = SpyGraphViewer() @@ -106,6 +137,7 @@ def test_draw_graph_show_import_totals(self): show_cycle_breakers=False, sys_path=[], current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, build_graph=build_fake_graph, viewer=viewer, ) @@ -127,6 +159,7 @@ def test_draw_graph_show_cycle_breakers(self): show_cycle_breakers=True, sys_path=[], current_directory="/cwd", + get_top_level_package=fake_get_top_level_package_non_namespace, build_graph=build_fake_graph, viewer=viewer, )