Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@
}

_PTXAS_VERSION_CACHE = None
_CUDA_VERSION_CACHE = None


def _get_max_ptx_version():
"""Detects the maximum PTX version supported by the available ptxas."""
global _PTXAS_VERSION_CACHE
global _PTXAS_VERSION_CACHE, _CUDA_VERSION_CACHE
if _PTXAS_VERSION_CACHE is not None:
return _PTXAS_VERSION_CACHE

Expand All @@ -105,10 +106,11 @@ def _get_max_ptx_version():
match = re.search(r"release (\d+)\.(\d+)", result.stdout)
if match:
major, minor = int(match.group(1)), int(match.group(2))
_CUDA_VERSION_CACHE = (major, minor)
# Map CUDA version to PTX version
if major == 12:
if minor >= 9:
version = 87 # 88 breaks some triton tests
version = 87 # 88 only for sm_121 (GB10/DGX Spark)
elif minor >= 8:
version = 87
elif minor >= 5:
Expand Down Expand Up @@ -166,6 +168,16 @@ def _compile_triton(
# Detect maximum supported PTX version
max_ptx_version = _get_max_ptx_version()

# For CUDA 12.9+, use PTX 88 only for sm_121 (GB10/DGX Spark)
# PTX 88 breaks other configurations
if (
_CUDA_VERSION_CACHE
and _CUDA_VERSION_CACHE[0] == 12
and _CUDA_VERSION_CACHE[1] >= 9
and compute_capability == 121
):
max_ptx_version = 88

# Base options supported by all Triton versions
cuda_options_kwargs = {
"num_warps": num_warps,
Expand Down