Description
Describe the bug
The documentation (https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.triangle_attention.html) states that:
> Triangle attention kernel supports: all hidden_dim<=32 and divisible by 4 for tf32/fp32, and for all hidden_dim<=128 and divisible by 8 for bf16/fp16
I'm running `cuequivariance_torch==0.8.1`, `cuequivariance-ops-cu13==0.8.1`, and `torch==2.10.0+cu130` with CUDA 13.0 on an RTX PRO 6000 Blackwell series GPU. When attempting to run the kernel on an input with `hidden_dim=128`, I get a CUDA error (see below for the full trace).
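To make the documented constraint concrete, here is a small sketch of the rule as I read it from the docs (the function name and dtype strings are mine, not part of the cuequivariance API):

```python
def hidden_dim_supported(hidden_dim: int, dtype: str) -> bool:
    """Check hidden_dim against the documented triangle_attention limits.

    Illustrative only; mirrors the docs quoted above, not the library's
    actual dispatch logic.
    """
    if dtype in ("fp32", "tf32"):
        # <= 32 and divisible by 4 for tf32/fp32
        return hidden_dim <= 32 and hidden_dim % 4 == 0
    if dtype in ("fp16", "bf16"):
        # <= 128 and divisible by 8 for bf16/fp16
        return hidden_dim <= 128 and hidden_dim % 8 == 0
    raise ValueError(f"unknown dtype: {dtype}")

# By these rules, hidden_dim=128 with fp16 should be in range:
assert hidden_dim_supported(128, "fp16")
```

So the failing configuration below (fp16, `hidden_dim=128`) sits exactly at the documented upper bound.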
To Reproduce
```python
import math

import torch

from cuequivariance_torch import triangle_attention


def main():
    device = torch.device("cuda")

    # Set up dimensions
    batch_size, seq_len, num_heads, hidden_dim = 2, 512, 4, 128

    # Create input tensors on GPU with float16 precision
    q = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    k = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    v = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    bias = torch.randn(
        batch_size, 1, num_heads, seq_len, seq_len, device=device, dtype=torch.float16, requires_grad=True
    )

    # Create optional mask
    mask = torch.rand(batch_size, seq_len, 1, 1, seq_len, device=device) < 0.5

    # Calculate scale
    scale = 1 / math.sqrt(hidden_dim)

    # Forward pass
    output, lse, max_val = triangle_attention(q=q, k=k, v=v, bias=bias, mask=mask, scale=scale, return_aux=True)
    print(output.shape)

    # Create gradient tensor and perform backward pass
    grad_out = torch.randn_like(output)
    output.backward(grad_out)

    # Access gradients
    print(q.grad.shape)
    print(k.grad.shape)
    print(v.grad.shape)
    print(bias.grad.shape)


if __name__ == "__main__":
    main()
```
Expected behavior
According to the documentation, I believe this input should run without errors.
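For reference, a plain-PyTorch version of what I understand the kernel to compute (softmax over the key axis of `q·kᵀ·scale + bias`, with masked positions set to `-inf`; the "True means keep" mask convention is my assumption, not taken from the library) runs fine on CPU at `hidden_dim=128`:

```python
import math

import torch


def triangle_attention_ref(q, k, v, bias, mask, scale):
    # q, k, v: (B, N, H, S, D); bias: (B, 1, H, S, S); mask: (B, N, 1, 1, S) bool.
    # Assumption: True in `mask` means "attend", False means "mask out".
    logits = torch.einsum("bnhid,bnhjd->bnhij", q, k) * scale + bias
    logits = logits.masked_fill(~mask, float("-inf"))
    attn = torch.softmax(logits, dim=-1)
    return torch.einsum("bnhij,bnhjd->bnhid", attn, v)


# Small CPU smoke test with the same hidden_dim as the failing case
B, N, H, S, D = 1, 4, 2, 8, 128
q = torch.randn(B, N, H, S, D)
k = torch.randn(B, N, H, S, D)
v = torch.randn(B, N, H, S, D)
bias = torch.randn(B, 1, H, S, S)
mask = torch.rand(B, N, 1, 1, S) < 0.5
mask[..., 0] = True  # keep at least one key per row so softmax stays finite
out = triangle_attention_ref(q, k, v, bias, mask, 1 / math.sqrt(D))
print(out.shape)  # torch.Size([1, 4, 2, 8, 128])
```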
Screenshots
The full stack trace is below.
```
/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py:165: UserWarning: Non-SM100f kernel expects bias to be float32 so it's going to be cast to torch.float32. Check if you can change your code for maximum performance.
  warnings.warn(
torch.Size([2, 512, 4, 512, 96])
Traceback (most recent call last):
  File "/net/home/pascal/sandbox-pascal/scripts/misc/test_kernel.py", line 42, in <module>
    main()
  File "/net/home/pascal/sandbox-pascal/scripts/misc/test_kernel.py", line 33, in main
    output.backward(grad_out)
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
    torch.autograd.backward(
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
    _engine_run_backward(
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 78, in backward
    result = info._backward_fn(ctx, *grads)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py", line 432, in _backward
    d_q, d_k, d_v, dbias = torch.ops.cuequivariance.triangle_attention_bwd(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 112, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 41, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_ops.py", line 826, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 347, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 382, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py", line 292, in _
    ops.triangle_attention_bwd(
RuntimeError: CUDA error: "invalid argument" at bwd_fmha.cu:193
Failed call: cudaFuncSetAttribute( (const void*)*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
```
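The failing call sets `cudaFuncAttributeMaxDynamicSharedMemorySize`, which returns `invalid argument` when the requested `smem_size` exceeds the device's opt-in shared-memory limit, or when the kernel was not built for the running architecture (the forward-pass warning already mentions a non-SM100f code path). In case it helps triage, a generic snippet (plain torch, nothing cuequivariance-specific) to report the SM version the runtime sees on this machine:

```python
import torch


def sm_arch(device_index: int = 0):
    """Return the GPU's SM architecture string (e.g. 'sm_120'), or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    props = torch.cuda.get_device_properties(device_index)
    return f"sm_{props.major}{props.minor}"


print(sm_arch())
```

On my machine this reports the Blackwell workstation architecture rather than SM100, which may be why the non-SM100f path is taken.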
GPU HW/SW (please complete the following information):
- CUDA toolkit versions:
  - cuda-bindings==13.0.3 ; sys_platform == 'linux'
  - nvidia-cuda-nvrtc==13.0.88 ; sys_platform == 'linux'
  - nvidia-cuda-runtime==13.0.96 ; sys_platform == 'linux'
  - nvidia-cuda-cupti==13.0.85 ; sys_platform == 'linux'
  - torch==2.10.0+cu130
- Driver version: 13.0
- Full name of GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
Any advice would be appreciated. Let me know if I can add more context.