Skip to content

Commit a793517

Browse files
committed
Use SITE_PACKAGES_LIBDIRS_WINDOWS in _find_dll_using_nvidia_bin_dirs, keep find_all_dll_files_via_metadata as fallback
1 parent 6adc349 commit a793517

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

cuda_pathfinder/cuda/pathfinder/_dynamic_libs/find_nvidia_dynamic_lib.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
1212
IS_WINDOWS,
1313
SITE_PACKAGES_LIBDIRS_LINUX,
14+
SITE_PACKAGES_LIBDIRS_WINDOWS,
1415
is_suppressed_dll_file,
1516
)
1617
from cuda.pathfinder._utils.find_site_packages_dll import find_all_dll_files_via_metadata
@@ -64,19 +65,29 @@ def _find_dll_under_dir(dirpath: str, file_wild: str) -> Optional[str]:
6465
return None
6566

6667

67-
def _find_dll_using_nvidia_bin_dirs(libname: str) -> Optional[str]:
68-
libname_lower = libname.lower()
69-
candidates = []
70-
for relname, abs_paths in find_all_dll_files_via_metadata().items():
71-
if is_suppressed_dll_file(relname):
72-
continue
73-
if relname.startswith(libname_lower):
74-
for abs_path in abs_paths:
75-
candidates.append(abs_path)
76-
if candidates:
77-
candidates.sort()
78-
result: str = candidates[0] # help mypy
79-
return result
68+
def _find_dll_using_nvidia_bin_dirs(libname: str, lib_searched_for: str) -> Optional[str]:
69+
rel_dirs = SITE_PACKAGES_LIBDIRS_WINDOWS.get(libname)
70+
if rel_dirs is not None:
71+
# Fast direct access with minimal globbing.
72+
for rel_dir in rel_dirs:
73+
for abs_dir in find_sub_dirs_all_sitepackages(tuple(rel_dir.split(os.path.sep))):
74+
dll_name = _find_dll_under_dir(abs_dir, lib_searched_for)
75+
if dll_name is not None:
76+
return dll_name
77+
else:
78+
# This fallback is relatively slow, but acceptable.
79+
libname_lower = libname.lower()
80+
candidates = []
81+
for relname, abs_paths in find_all_dll_files_via_metadata().items():
82+
if is_suppressed_dll_file(relname):
83+
continue
84+
if relname.startswith(libname_lower):
85+
for abs_path in abs_paths:
86+
candidates.append(abs_path)
87+
if candidates:
88+
candidates.sort()
89+
result: str = candidates[0] # help mypy
90+
return result
8091
return None
8192

8293

@@ -158,7 +169,7 @@ def __init__(self, libname: str):
158169
if IS_WINDOWS:
159170
self.lib_searched_for = f"{libname}*.dll"
160171
if self.abs_path is None:
161-
self.abs_path = _find_dll_using_nvidia_bin_dirs(libname)
172+
self.abs_path = _find_dll_using_nvidia_bin_dirs(libname, self.lib_searched_for)
162173
else:
163174
self.lib_searched_for = f"lib{libname}.so"
164175
if self.abs_path is None:

0 commit comments

Comments
 (0)