diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 314dee1eb3..bbf72db966 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -76,10 +76,41 @@ from .comm import gen_trtllm_comm_module as gen_trtllm_comm_module from .comm import gen_vllm_comm_module as gen_vllm_comm_module from .comm import gen_nvshmem_module as gen_nvshmem_module +from typing import Optional + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), ( + f"Unexpected filename: {filename} for library {lib_name}" + ) + return path cuda_lib_path = os.environ.get( "CUDA_LIB_PATH", "/usr/local/cuda/targets/x86_64-linux/lib/" ) -if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"): +process_cudart_path = find_loaded_library("libcudart") +if process_cudart_path is not None: + ctypes.CDLL(process_cudart_path, mode=ctypes.RTLD_GLOBAL) +elif os.path.exists(f"{cuda_lib_path}/libcudart.so.12"): ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)