Out of shared memory on Blackwell architectures with embedding size 128 #244

@psturmfels

Description

Describe the bug
The documentation (https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.triangle_attention.html) states that:

Triangle attention kernel supports: all hidden_dim<=32 and divisible by 4 for tf32/fp32, and for all hidden_dim<=128 and divisible by 8 for bf16/fp16

I'm running cuequivariance_torch==0.8.1, cuequivariance-ops-cu13==0.8.1, and torch==2.10.0+cu130 with CUDA 13.0 on an NVIDIA RTX PRO 6000 Blackwell GPU. When I run the kernel on an input with hidden_dim=128, the backward pass fails with a CUDA error (full trace below).
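Read literally, the quoted constraint says hidden_dim=128 should be accepted for fp16/bf16. Here is a tiny checker encoding that reading (my own helper for illustration, not part of the library):

```python
def hidden_dim_supported(hidden_dim: int, dtype: str) -> bool:
    """Literal reading of the documented triangle_attention constraint."""
    if dtype in ("fp32", "tf32"):
        return hidden_dim <= 32 and hidden_dim % 4 == 0
    if dtype in ("fp16", "bf16"):
        return hidden_dim <= 128 and hidden_dim % 8 == 0
    return False

print(hidden_dim_supported(128, "fp16"))  # True: the failing case below is documented as supported
print(hidden_dim_supported(128, "fp32"))  # False: fp32/tf32 tops out at 32
```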

To Reproduce

import math

import torch
from cuequivariance_torch import triangle_attention


def main():
    device = torch.device("cuda")
    # Set up dimensions
    batch_size, seq_len, num_heads, hidden_dim = 2, 512, 4, 128
    # Create input tensors on GPU with float16 precision
    q = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    k = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    v = torch.randn(
        batch_size, seq_len, num_heads, seq_len, hidden_dim, device=device, dtype=torch.float16, requires_grad=True
    )
    bias = torch.randn(
        batch_size, 1, num_heads, seq_len, seq_len, device=device, dtype=torch.float16, requires_grad=True
    )
    # Create optional mask
    mask = torch.rand(batch_size, seq_len, 1, 1, seq_len, device=device) < 0.5
    # Calculate scale
    scale = 1 / math.sqrt(hidden_dim)
    # Forward pass
    output, lse, max_val = triangle_attention(q=q, k=k, v=v, bias=bias, mask=mask, scale=scale, return_aux=True)
    print(output.shape)
    # Create gradient tensor and perform backward pass
    grad_out = torch.randn_like(output)
    output.backward(grad_out)
    # Access gradients
    print(q.grad.shape)
    print(k.grad.shape)
    print(v.grad.shape)
    print(bias.grad.shape)


if __name__ == "__main__":
    main()
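For sanity-checking the expected semantics on small inputs, here is a plain-PyTorch reference of the same attention pattern. This is my own sketch, assuming softmax(q·kᵀ·scale + bias) with True-means-attend mask semantics; it is not the fused kernel and will not reproduce the crash:

```python
import torch


def triangle_attention_ref(q, k, v, bias=None, mask=None, scale=1.0):
    # q, k, v: (B, N, H, N, D); bias: (B, 1, H, N, N); mask: (B, N, 1, 1, N) bool
    logits = torch.einsum("bnhqd,bnhkd->bnhqk", q, k) * scale
    if bias is not None:
        logits = logits + bias  # broadcasts over the second dimension
    if mask is not None:
        logits = logits.masked_fill(~mask, float("-inf"))  # assumes True = attend
    attn = torch.softmax(logits, dim=-1)
    return torch.einsum("bnhqk,bnhkd->bnhqd", attn, v)


# tiny CPU smoke test
q = torch.randn(1, 3, 2, 3, 8)
out = triangle_attention_ref(q, q, q, scale=8 ** -0.5)
print(out.shape)  # torch.Size([1, 3, 2, 3, 8])
```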

Expected behavior
According to the documentation, this input should run without errors.

Screenshots
The full stack trace is below.

/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py:165: UserWarning: Non-SM100f kernel expects bias to be float32 so it's going to be cast to torch.float32. Check if you can change your code for maximum performance.
  warnings.warn(
torch.Size([2, 512, 4, 512, 96])
Traceback (most recent call last):
  File "/net/home/pascal/sandbox-pascal/scripts/misc/test_kernel.py", line 42, in <module>
    main()
  File "/net/home/pascal/sandbox-pascal/scripts/misc/test_kernel.py", line 33, in main
    output.backward(grad_out)
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
    torch.autograd.backward(
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
    _engine_run_backward(
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 78, in backward
    result = info._backward_fn(ctx, *grads)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py", line 432, in _backward
    d_q, d_k, d_v, dbias = torch.ops.cuequivariance.triangle_attention_bwd(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_ops.py", line 1209, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 112, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/autograd.py", line 41, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_ops.py", line 826, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)  # type: ignore[return-value]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 347, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/torch/_library/custom_ops.py", line 382, in wrapped_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/net/home/pascal/sandbox-pascal/.pixi/envs/default/lib/python3.12/site-packages/cuequivariance_ops_torch/triangle_attention.py", line 292, in _
    ops.triangle_attention_bwd(
RuntimeError: CUDA error: "invalid argument" at bwd_fmha.cu:193
Failed call: cudaFuncSetAttribute( (const void*)*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
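For what it's worth, `cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)` returns "invalid argument" when `smem_size` exceeds the device's opt-in per-block shared memory limit. If I have the compute-capability table right, workstation/GeForce Blackwell (SM 12.0, this GPU) allows at most ~99 KiB per block, while datacenter Blackwell (SM 10.0) allows ~227 KiB, so a backward kernel tuned for the larger budget would hit exactly this error here. A back-of-envelope estimate with a hypothetical 64-row tile (my own guess at a staging layout, not the kernel's actual one):

```python
TILE, D = 64, 128          # hypothetical tile rows x head dim
FP16, FP32 = 2, 4          # bytes per element
# stage q, k, v, dO tiles in fp16, plus fp32 score/d-score and bias tiles
smem = 4 * TILE * D * FP16 + 2 * TILE * TILE * FP32 + TILE * TILE * FP32
print(smem, smem / 1024)   # 114688 bytes = 112.0 KiB

SM120_LIMIT = 99 * 1024    # workstation/GeForce Blackwell opt-in max per block
SM100_LIMIT = 227 * 1024   # datacenter Blackwell opt-in max per block
print(smem <= SM120_LIMIT, smem <= SM100_LIMIT)  # False True
```

Under these assumptions the request fits the datacenter budget but not the SM 12.0 one, which would match the "Non-SM100f kernel" warning above.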

GPU HW/SW (please complete the following information):

  • CUDA toolkit versions:
      • cuda-bindings==13.0.3 ; sys_platform == 'linux'
      • nvidia-cuda-nvrtc==13.0.88 ; sys_platform == 'linux'
      • nvidia-cuda-runtime==13.0.96 ; sys_platform == 'linux'
      • nvidia-cuda-cupti==13.0.85 ; sys_platform == 'linux'
      • torch==2.10.0+cu130
  • Driver version: 13.0
  • Full name of GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition

Any advice would be appreciated. Let me know if I can add more context.
