diff --git a/src/grimp/application/usecases.py b/src/grimp/application/usecases.py index bb43d7a8..e19bb6e8 100644 --- a/src/grimp/application/usecases.py +++ b/src/grimp/application/usecases.py @@ -15,13 +15,15 @@ from ..domain.valueobjects import DirectImport, Module from .config import settings -N_CPUS = multiprocessing.cpu_count() - class NotSupplied: pass +# This is an arbitrary number, but setting it too low slows down our functional tests considerably. +MIN_NUMBER_OF_MODULES_TO_SCAN_USING_MULTIPROCESSING = 50 + + def build_graph( package_name, *additional_package_names, @@ -209,32 +211,66 @@ def _scan_imports( include_external_packages: bool, exclude_type_checking_imports: bool, ) -> Dict[ModuleFile, Set[DirectImport]]: - import_scanner: AbstractImportScanner = settings.IMPORT_SCANNER_CLASS( - file_system=file_system, - found_packages=found_packages, - include_external_packages=include_external_packages, + chunks = _create_chunks(module_files) + return _scan_chunks( + chunks, + file_system, + found_packages, + include_external_packages, + exclude_type_checking_imports, ) - imports_by_module_file: Dict[ModuleFile, Set[DirectImport]] = {} - n_chunks = min(N_CPUS, len(module_files)) - chunks = _create_chunks(list(module_files), n_chunks=n_chunks) - with multiprocessing.Pool(n_chunks) as pool: - import_scanning_jobs = pool.starmap( - _scan_chunk, - [(import_scanner, exclude_type_checking_imports, chunk) for chunk in chunks], - ) - for chunk_imports_by_module_file in import_scanning_jobs: - imports_by_module_file.update(chunk_imports_by_module_file) +def _create_chunks(module_files: Collection[ModuleFile]) -> tuple[tuple[ModuleFile, ...], ...]: + """ + Split the module files into chunks, each to be worked on by a separate OS process. + """ + module_files_tuple = tuple(module_files) - return imports_by_module_file + number_of_module_files = len(module_files_tuple) + n_chunks = _decide_number_of_of_processes(number_of_module_files) + chunk_size = math.ceil(number_of_module_files / n_chunks) + + return tuple( + module_files_tuple[i * chunk_size : (i + 1) * chunk_size] for i in range(n_chunks) + ) -def _create_chunks( - module_files: Sequence[ModuleFile], *, n_chunks: int -) -> Iterable[Iterable[ModuleFile]]: - chunk_size = math.ceil(len(module_files) / n_chunks) - return [module_files[i * chunk_size : (i + 1) * chunk_size] for i in range(n_chunks)] +def _decide_number_of_of_processes(number_of_module_files: int) -> int: + if number_of_module_files < MIN_NUMBER_OF_MODULES_TO_SCAN_USING_MULTIPROCESSING: + # Don't incur the overhead of multiprocessing. + return 1 + return min(multiprocessing.cpu_count(), number_of_module_files) + + +def _scan_chunks( + chunks: Collection[Collection[ModuleFile]], + file_system: AbstractFileSystem, + found_packages: Set[FoundPackage], + include_external_packages: bool, + exclude_type_checking_imports: bool, +) -> Dict[ModuleFile, Set[DirectImport]]: + import_scanner: AbstractImportScanner = settings.IMPORT_SCANNER_CLASS( + file_system=file_system, + found_packages=found_packages, + include_external_packages=include_external_packages, + ) + + number_of_processes = len(chunks) + if number_of_processes == 1: + # No need to spawn a process if there's only one chunk. + [chunk] = chunks + return _scan_chunk(import_scanner, exclude_type_checking_imports, chunk) + else: + with multiprocessing.Pool(number_of_processes) as pool: + imports_by_module_file: Dict[ModuleFile, Set[DirectImport]] = {} + import_scanning_jobs = pool.starmap( + _scan_chunk, + [(import_scanner, exclude_type_checking_imports, chunk) for chunk in chunks], + ) + for chunk_imports_by_module_file in import_scanning_jobs: + imports_by_module_file.update(chunk_imports_by_module_file) + return imports_by_module_file def _scan_chunk( diff --git a/tests/functional/test_build_and_use_graph.py b/tests/functional/test_build_and_use_graph.py index ed2ce877..76b728e7 100644 --- a/tests/functional/test_build_and_use_graph.py +++ b/tests/functional/test_build_and_use_graph.py @@ -1,6 +1,8 @@ from grimp import build_graph from typing import Set, Tuple, Optional import pytest +from unittest.mock import patch +from grimp.application import usecases """ For ease of reference, these are the imports of all the files: @@ -53,6 +55,33 @@ def test_modules(): } +@patch.object(usecases, "MIN_NUMBER_OF_MODULES_TO_SCAN_USING_MULTIPROCESSING", 0) +def test_modules_multiprocessing(): + """ + This test runs relatively slowly, but it's important we cover the multiprocessing code. + """ + graph = build_graph("testpackage", cache_dir=None) + + assert graph.modules == { + "testpackage", + "testpackage.one", + "testpackage.one.alpha", + "testpackage.one.beta", + "testpackage.one.gamma", + "testpackage.one.delta", + "testpackage.one.delta.blue", + "testpackage.two", + "testpackage.two.alpha", + "testpackage.two.beta", + "testpackage.two.gamma", + "testpackage.utils", + "testpackage.three", + "testpackage.three.beta", + "testpackage.three.gamma", + "testpackage.three.alpha", + } + + def test_add_module(): graph = build_graph("testpackage", cache_dir=None) number_of_modules = len(graph.modules)