diff --git a/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py b/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py index 0437e30..42d411d 100644 --- a/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py +++ b/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py @@ -87,11 +87,12 @@ } _PTXAS_VERSION_CACHE = None +_CUDA_VERSION_CACHE = None def _get_max_ptx_version(): """Detects the maximum PTX version supported by the available ptxas.""" - global _PTXAS_VERSION_CACHE + global _PTXAS_VERSION_CACHE, _CUDA_VERSION_CACHE if _PTXAS_VERSION_CACHE is not None: return _PTXAS_VERSION_CACHE @@ -105,10 +106,11 @@ def _get_max_ptx_version(): match = re.search(r"release (\d+)\.(\d+)", result.stdout) if match: major, minor = int(match.group(1)), int(match.group(2)) + _CUDA_VERSION_CACHE = (major, minor) # Map CUDA version to PTX version if major == 12: if minor >= 9: - version = 87 # 88 breaks some triton tests + version = 87 # 88 only for sm_121 (GB10/DGX Spark) elif minor >= 8: version = 87 elif minor >= 5: @@ -166,6 +168,16 @@ def _compile_triton( # Detect maximum supported PTX version max_ptx_version = _get_max_ptx_version() + # For CUDA 12.9+, use PTX 88 only for sm_121 (GB10/DGX Spark) + # PTX 88 breaks other configurations + if ( + _CUDA_VERSION_CACHE + and _CUDA_VERSION_CACHE[0] == 12 + and _CUDA_VERSION_CACHE[1] >= 9 + and compute_capability == 121 + ): + max_ptx_version = 88 + # Base options supported by all Triton versions cuda_options_kwargs = { "num_warps": num_warps,