From d0f899b55ed5f1aeb281362326f1094d1af0c105 Mon Sep 17 00:00:00 2001 From: hari sadasivan Date: Mon, 9 Feb 2026 13:37:45 -0800 Subject: [PATCH] Restrict PTX 88 to sm_121 for CUDA 12.9+ PTX version 88 breaks other configurations with CUDA 12.9. This change ensures PTX 88 is only used for GB10/sm_121 (DGX Spark), while other architectures use the safer PTX 87. Co-Authored-By: Claude Sonnet 4.5 --- .../cuequivariance_jax/triangle/triton_utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py b/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py index 0437e30..42d411d 100644 --- a/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py +++ b/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py @@ -87,11 +87,12 @@ } _PTXAS_VERSION_CACHE = None +_CUDA_VERSION_CACHE = None def _get_max_ptx_version(): """Detects the maximum PTX version supported by the available ptxas.""" - global _PTXAS_VERSION_CACHE + global _PTXAS_VERSION_CACHE, _CUDA_VERSION_CACHE if _PTXAS_VERSION_CACHE is not None: return _PTXAS_VERSION_CACHE @@ -105,10 +106,11 @@ def _get_max_ptx_version(): match = re.search(r"release (\d+)\.(\d+)", result.stdout) if match: major, minor = int(match.group(1)), int(match.group(2)) + _CUDA_VERSION_CACHE = (major, minor) # Map CUDA version to PTX version if major == 12: if minor >= 9: - version = 87 # 88 breaks some triton tests + version = 87 # 88 only for sm_121 (GB10/DGX Spark) elif minor >= 8: version = 87 elif minor >= 5: @@ -166,6 +168,16 @@ def _compile_triton( # Detect maximum supported PTX version max_ptx_version = _get_max_ptx_version() + # For CUDA 12.9+, use PTX 88 only for sm_121 (GB10/DGX Spark) + # PTX 88 breaks other configurations + if ( + _CUDA_VERSION_CACHE + and _CUDA_VERSION_CACHE[0] == 12 + and _CUDA_VERSION_CACHE[1] >= 9 + and compute_capability == 121 + ): + max_ptx_version = 88 + # Base options supported by all Triton versions cuda_options_kwargs = { "num_warps": num_warps,