From def3ec16bdd395e4d361df9c6fc1e5f273dec7e9 Mon Sep 17 00:00:00 2001 From: Peter Byfield Date: Thu, 1 May 2025 06:33:48 +0200 Subject: [PATCH] Import Scanning: Split raw import extraction from import resolution --- src/grimp/adaptors/importscanner.py | 393 +++++++++------------- tests/unit/adaptors/test_importscanner.py | 86 ++++- 2 files changed, 248 insertions(+), 231 deletions(-) diff --git a/src/grimp/adaptors/importscanner.py b/src/grimp/adaptors/importscanner.py index 5aed4d76..863ebb64 100644 --- a/src/grimp/adaptors/importscanner.py +++ b/src/grimp/adaptors/importscanner.py @@ -1,5 +1,9 @@ +from __future__ import annotations + import ast +import re import logging +from dataclasses import dataclass from typing import Dict, List, Optional, Set, Union from ast import NodeVisitor, Import, ImportFrom, If, Attribute, Name @@ -10,6 +14,16 @@ logger = logging.getLogger(__name__) +_LEADING_DOT_REGEX = re.compile(r"^(\.+)\w") + + +@dataclass(frozen=True) +class _ImportedObject: + name: str + line_number: int + line_contents: str + typechecking_only: bool + class ImportScanner(AbstractImportScanner): def __init__(self, *args, **kwargs) -> None: @@ -32,12 +46,10 @@ def scan_for_imports( """ found_package = self._found_package_for_module(module) module_filename = self._determine_module_filename(module, found_package) - is_package = self._module_is_package(module_filename) module_contents = self._read_module_contents(module_filename) - module_lines = module_contents.splitlines() try: - ast_tree = ast.parse(module_contents) + imported_objects = self._get_raw_imported_objects(module_contents) except SyntaxError as e: raise exceptions.SourceSyntaxError( filename=module_filename, @@ -45,34 +57,50 @@ def scan_for_imports( text=e.text, ) - from_import_parser = _ImportFromNodeParser( - module=module, - found_package=found_package, - found_packages=self.found_packages, - found_packages_by_module=self._found_packages_by_module, - include_external_packages=self.include_external_packages, - is_package=is_package, - ) + is_package = self._module_is_package(module_filename) - import_parser = _ImportNodeParser( - module=module, - found_package=found_package, - found_packages=self.found_packages, - found_packages_by_module=self._found_packages_by_module, - include_external_packages=self.include_external_packages, - is_package=is_package, - ) + imports = set() + for imported_object in imported_objects: + # Filter on `exclude_type_checking_imports`. + if exclude_type_checking_imports and imported_object.typechecking_only: + continue - walker = _TreeWalker( - import_parser=import_parser, - from_import_parser=from_import_parser, - module=module, - module_lines=module_lines, - exclude_type_checking_imports=exclude_type_checking_imports, - ) - walker.visit(ast_tree) + # Resolve relative imports. + imported_object_name = self._get_absolute_imported_object_name( + module=module, is_package=is_package, imported_object_name=imported_object.name + ) + + # Resolve imported module. + imported_module = self._get_internal_module(imported_object_name, modules=self.modules) + if imported_module is None: + # => External import. + + # Filter on `self.include_external_packages`. + if not self.include_external_packages: + continue + + # Distill module. + imported_module = self._distill_external_module( + Module(imported_object_name), found_packages=self.found_packages + ) + if imported_module is None: + continue - return walker.direct_imports + imports.add( + DirectImport( + importer=module, + imported=imported_module, + line_number=imported_object.line_number, + line_contents=imported_object.line_contents, + ) + ) + return imports + + def _found_package_for_module(self, module: Module) -> FoundPackage: + try: + return self._found_packages_by_module[module] + except KeyError: + raise ValueError(f"No found package for module {module}.") def _determine_module_filename(self, module: Module, found_package: FoundPackage) -> str: """ @@ -96,12 +124,6 @@ def _determine_module_filename(self, module: Module, found_package: FoundPackage return candidate_filename raise FileNotFoundError(f"Could not find module {module}.") - def _found_package_for_module(self, module: Module) -> FoundPackage: - try: - return self._found_packages_by_module[module] - except KeyError: - raise ValueError(f"No found package for module {module}.") - def _read_module_contents(self, module_filename: str) -> str: """ Read the file contents of the module. @@ -114,52 +136,59 @@ def _module_is_package(self, module_filename: str) -> bool: """ return self.file_system.split(module_filename)[-1] == "__init__.py" + @staticmethod + def _get_raw_imported_objects(module_contents: str) -> Set[_ImportedObject]: + module_lines = module_contents.splitlines() -class _BaseNodeParser: - """ - Works out from an AST node what the imported modules are. - """ + ast_tree = ast.parse(module_contents) - def __init__( - self, - module: Module, - found_package: FoundPackage, - found_packages: Set[FoundPackage], - found_packages_by_module: Dict[Module, FoundPackage], - is_package: bool, - include_external_packages: bool, - ) -> None: - self.module = module - self.found_package = found_package - self.found_packages = found_packages - self.module_is_package = is_package - self.found_packages_by_module = found_packages_by_module - self.include_external_packages = include_external_packages - - def determine_imported_modules(self, node: ast.AST) -> Set[Module]: - """ - Return the imported modules in the statement. - """ - raise NotImplementedError - - def _is_internal_module(self, module: Module) -> bool: - return module in self.found_packages_by_module + visitor = _TreeVisitor( + module_lines=module_lines, + ) + visitor.visit(ast_tree) + + return visitor.imported_objects + + @staticmethod + def _get_absolute_imported_object_name( + *, module: Module, is_package: bool, imported_object_name: str + ) -> str: + leading_dot_match = _LEADING_DOT_REGEX.match(imported_object_name) + if leading_dot_match is None: + return imported_object_name + + n_leading_dots = len(leading_dot_match.group(1)) + if is_package: + if n_leading_dots == 1: + imported_object_name_base = module.name + else: + imported_object_name_base = ".".join( + module.name.split(".")[: -(n_leading_dots - 1)] + ) + else: + imported_object_name_base = ".".join(module.name.split(".")[:-n_leading_dots]) + return imported_object_name_base + "." + imported_object_name[n_leading_dots:] - def _is_internal_object(self, full_object_name: str) -> bool: - # Build a Module that may or may not exist. - candidate_module = Module(full_object_name) - if self._is_internal_module(candidate_module): - return True + @staticmethod + def _get_internal_module(object_name: str, *, modules: Set[Module]) -> Optional[Module]: + candidate_module = Module(object_name) + if candidate_module in modules: + return candidate_module - # Also check the parent. In the case of non-module objects, this may be an internal module. try: - parent = candidate_module.parent + candidate_module = candidate_module.parent except ValueError: - return False + return None else: - return self._is_internal_module(parent) + if candidate_module in modules: + return candidate_module + else: + return None - def _distill_external_module(self, module: Module) -> Optional[Module]: + @staticmethod + def _distill_external_module( + module: Module, *, found_packages: Set[FoundPackage] + ) -> Optional[Module]: """ Given a module that we already know is external, turn it into a module to add to the graph. @@ -176,13 +205,13 @@ def _distill_external_module(self, module: Module) -> Optional[Module]: """ # If it's a module that is a parent of one of the internal packages, return None # as it doesn't make sense and is probably an import of a namespace package. - if any(Module(package.name).is_descendant_of(module) for package in self.found_packages): + if any(Module(package.name).is_descendant_of(module) for package in found_packages): return None # If it shares a namespace with an internal module, get the shallowest component that does # not clash with an internal module namespace. candidate_portions: Set[Module] = set() - for found_package in sorted(self.found_packages, key=lambda p: p.name, reverse=True): + for found_package in sorted(found_packages, key=lambda p: p.name, reverse=True): root_module = Module(found_package.name) if root_module.is_descendant_of(module.root): ( @@ -210,181 +239,85 @@ def _distill_external_module(self, module: Module) -> Optional[Module]: return module.root -class _ImportNodeParser(_BaseNodeParser): +class _TreeVisitor(NodeVisitor): + def __init__( + self, + module_lines: List[str], + ) -> None: + self.import_parser = _ImportNodeParser() + self.from_import_parser = _ImportFromNodeParser() + self.module_lines = module_lines + + self.imported_objects: Set[_ImportedObject] = set() + self.typechecking_only = False + + super().__init__() + + def visit_Import(self, node: Import) -> None: + self._parse_imported_objects_from_node(node, self.import_parser) + + def visit_ImportFrom(self, node: ImportFrom) -> None: + self._parse_imported_objects_from_node(node, self.from_import_parser) + + def visit_If(self, node: If) -> None: + if (isinstance(node.test, Name) and node.test.id == "TYPE_CHECKING") or ( + isinstance(node.test, Attribute) and node.test.attr == "TYPE_CHECKING" + ): + self.typechecking_only = True + super().generic_visit(node) + self.typechecking_only = False + else: + super().generic_visit(node) + + def _parse_imported_objects_from_node( + self, + node: Union[Import, ImportFrom], + parser: Union[_ImportNodeParser, _ImportFromNodeParser], + ) -> None: + for imported_object in parser.determine_imported_objects(node): + self.imported_objects.add( + _ImportedObject( + name=imported_object, + line_number=node.lineno, + line_contents=self.module_lines[node.lineno - 1].strip(), + typechecking_only=self.typechecking_only, + ) + ) + + +class _ImportNodeParser: """ Parser for statements in the form 'import x'. """ node_class = ast.Import - def determine_imported_modules(self, node: ast.AST) -> Set[Module]: - imported_modules: Set[Module] = set() - + def determine_imported_objects(self, node: ast.AST) -> Set[str]: + imported_objects: Set[str] = set() assert isinstance(node, self.node_class) # For type checker. for alias in node.names: - imported_module = self._module_from_name(alias.name) - if imported_module: - imported_modules.add(imported_module) + imported_object = alias.name + imported_objects.add(imported_object) + return imported_objects - return imported_modules - def _module_from_name(self, module_name: str) -> Optional[Module]: - module = Module(module_name) - if self._is_internal_module(module): - return module - else: - if self.include_external_packages: - return self._distill_external_module(module) - else: - return None - - -class _ImportFromNodeParser(_BaseNodeParser): +class _ImportFromNodeParser: """ Parser for statements in the form 'from x import ...'. """ node_class = ast.ImportFrom - def determine_imported_modules(self, node: ast.AST) -> Set[Module]: - imported_modules: Set[Module] = set() + def determine_imported_objects(self, node: ast.AST) -> Set[str]: + imported_objects: Set[str] = set() assert isinstance(node, self.node_class) # For type checker. assert isinstance(node.level, int) # For type checker. - if node.level == 0: - # Absolute import. - # Let the type checker know we expect node.module to be set here. - assert isinstance(node.module, str) - node_module = Module(node.module) - if not self._is_internal_module(node_module): - if self.include_external_packages: - # Return the top level package of the external module. - external_modules = set() - for alias in node.names: - full_object_name = ".".join([node.module, alias.name]) - untrimmed_module = Module(full_object_name) - external_module = self._distill_external_module(untrimmed_module) - if external_module: - external_modules.add(external_module) - return external_modules - else: - return set() - # Don't include imports of modules outside this package. - - module_base = node.module - elif node.level >= 1: - # Relative import. The level corresponds to how high up the tree it goes; - # for example 'from ... import foo' would be level 3. - importing_module_components = self.module.name.split(".") - # TODO: handle level that is too high. - # Trim the base module by the number of levels. - if self.module_is_package: - # If the scanned module an __init__.py file, we don't want - # to go up an extra level. - number_of_levels_to_trim_by = node.level - 1 - else: - number_of_levels_to_trim_by = node.level - - if number_of_levels_to_trim_by: - module_base = ".".join(importing_module_components[:-number_of_levels_to_trim_by]) - else: - module_base = ".".join(importing_module_components) - if node.module: - module_base = ".".join([module_base, node.module]) - - # node.names corresponds to 'a', 'b' and 'c' in 'from x import a, b, c'. for alias in node.names: - full_object_name = ".".join([module_base, alias.name]) - imported_module = self._module_from_object_name(full_object_name) - if imported_module: - imported_modules.add(imported_module) - - return imported_modules - - def _trim_to_internal_module(self, untrimmed_module: Module) -> Module: - """ - Raises FileNotFoundError if it could not find a valid module. - """ - if self._is_internal_module(untrimmed_module): - return untrimmed_module - else: - # The module isn't in the internal modules. This is because it's something *within* - # a module (e.g. a function): the result of something like 'from .subpackage - # import my_function'. So we trim the components back to the module. - components = untrimmed_module.name.split(".")[:-1] - trimmed_module = Module(".".join(components)) - - if self._is_internal_module(trimmed_module): - return trimmed_module - else: - raise FileNotFoundError() - - def _module_from_object_name(self, full_object_name: str) -> Optional[Module]: - if self._is_internal_object(full_object_name): - untrimmed_module = Module(full_object_name) - try: - imported_module = self._trim_to_internal_module(untrimmed_module=untrimmed_module) - except FileNotFoundError: - logger.warning( - f"Could not find {full_object_name} when scanning {self.module}. " - "This may be due to a missing __init__.py file in the parent package." - ) + if node.module is None: + imported_object = f"{'.' * node.level}{alias.name}" else: - return imported_module - else: - untrimmed_module = Module(full_object_name) - if self.include_external_packages: - return self._distill_external_module(untrimmed_module) - return None + imported_object = f"{'.' * node.level}{node.module}.{alias.name}" + imported_objects.add(imported_object) - -class _TreeWalker(NodeVisitor): - def __init__( - self, - import_parser: _ImportNodeParser, - from_import_parser: _ImportFromNodeParser, - module: Module, - module_lines: List[str], - *, - exclude_type_checking_imports: bool, - ) -> None: - self.module = module - self.module_lines = module_lines - self.exclude_type_checking_imports = exclude_type_checking_imports - self.direct_imports: Set[DirectImport] = set() - self.import_parser = import_parser - self.from_import_parser = from_import_parser - super().__init__() - - def visit_Import(self, node: Import) -> None: - self._parse_direct_imports_from_node(node, self.import_parser) - - def visit_ImportFrom(self, node: ImportFrom) -> None: - self._parse_direct_imports_from_node(node, self.from_import_parser) - - def visit_If(self, node: If) -> None: - if self.exclude_type_checking_imports: - # Case for "if TYPE_CHECKING:" - if isinstance(node.test, Name) and node.test.id == "TYPE_CHECKING": - return # Skip parsing - - # Case for "if xxx.TYPE_CHECKING:" - if isinstance(node.test, Attribute) and node.test.attr == "TYPE_CHECKING": - return # Skip parsing - - super().generic_visit(node) - - def _parse_direct_imports_from_node( - self, - node: Union[Import, ImportFrom], - parser: Union[_ImportNodeParser, _ImportFromNodeParser], - ) -> None: - for imported in parser.determine_imported_modules(node): - self.direct_imports.add( - DirectImport( - importer=self.module, - imported=imported, - line_number=node.lineno, - line_contents=self.module_lines[node.lineno - 1].strip(), - ) - ) + return imported_objects diff --git a/tests/unit/adaptors/test_importscanner.py b/tests/unit/adaptors/test_importscanner.py index 9c3fb12f..f01480ee 100644 --- a/tests/unit/adaptors/test_importscanner.py +++ b/tests/unit/adaptors/test_importscanner.py @@ -2,7 +2,7 @@ import pytest # type: ignore -from grimp.adaptors.importscanner import ImportScanner +from grimp.adaptors.importscanner import ImportScanner, _ImportedObject from grimp.application.ports.modulefinder import FoundPackage, ModuleFile from grimp.domain.valueobjects import DirectImport, Module from tests.adaptors.filesystem import FakeFileSystem @@ -842,3 +842,87 @@ def test_exclude_type_checking_imports( def _modules_to_module_files(modules: Set[Module]) -> Set[ModuleFile]: some_mtime = 100933.4 return {ModuleFile(module=module, mtime=some_mtime) for module in modules} + + +def test_get_raw_imports(): + module_contents = """\ +import a +if TYPE_CHECKING: + import b +from c import d +from .e import f +from . import g +from .. import h +from i import * +""" + + raw_imported_objects = ImportScanner._get_raw_imported_objects(module_contents) + + assert raw_imported_objects == { + _ImportedObject( + name="a", + line_number=1, + line_contents="import a", + typechecking_only=False, + ), + _ImportedObject( + name="b", + line_number=3, + line_contents="import b", + typechecking_only=True, + ), + _ImportedObject( + name="c.d", + line_number=4, + line_contents="from c import d", + typechecking_only=False, + ), + _ImportedObject( + name=".e.f", + line_number=5, + line_contents="from .e import f", + typechecking_only=False, + ), + _ImportedObject( + name=".g", + line_number=6, + line_contents="from . import g", + typechecking_only=False, + ), + _ImportedObject( + name="..h", + line_number=7, + line_contents="from .. import h", + typechecking_only=False, + ), + _ImportedObject( + name="i.*", + line_number=8, + line_contents="from i import *", + typechecking_only=False, + ), + } + + +@pytest.mark.parametrize( + "is_package,imported_object_name,expected_absolute_imported_object_name", + [ + [True, "a.b", "a.b"], + [True, ".a.b", "foo.bar.baz.a.b"], + [True, "..a.b", "foo.bar.a.b"], + [False, "a.b", "a.b"], + [False, ".a.b", "foo.bar.a.b"], + [False, "..a.b", "foo.a.b"], + ], +) +def test_get_absolute_imported_object_name( + is_package, imported_object_name, expected_absolute_imported_object_name +): + assert ( + ImportScanner._get_absolute_imported_object_name( + module=Module("foo.bar.baz"), + is_package=is_package, + imported_object_name=imported_object_name, + ) + == expected_absolute_imported_object_name + )