
Conversation

@javak87 (Contributor) commented Dec 29, 2025

Description

Computing the max needed for flash_attn_varlen_func inside the attention module forces a device-to-host transfer and therefore an unnecessary CUDA synchronization; it also prevents the module from being compiled. This PR moves the max operation as far up the call stack as possible so the sync no longer happens at the attention-module level.
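For illustration, a minimal sketch of the idea, assuming the sequence lengths are available as a host-side list where the batch is assembled (the names build_batch, VarlenAttention and x_lens are hypothetical, not the actual WeatherGenerator code): the max is computed once as a plain Python int and handed down to the attention module, so forward never needs a device-to-host conversion.

    import torch
    from flash_attn import flash_attn_varlen_func

    def build_batch(seq_lens: list[int], device: str = "cuda"):
        # seq_lens lives on the host, so max() is an ordinary Python call with no CUDA sync.
        x_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
        cu_seqlens = torch.nn.functional.pad(x_lens.cumsum(0), (1, 0)).to(torch.int32)
        max_seqlen = max(seq_lens)  # plain Python int
        return cu_seqlens, max_seqlen

    class VarlenAttention(torch.nn.Module):
        def forward(self, q, k, v, cu_seqlens, max_seqlen: int):
            # max_seqlen already arrives as an int, so nothing here calls .item()/int().
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
            )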

Issue Number

Closes #1531

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@javak87 changed the title from "Move the CUDA sync caused by the max operation to a higher level." to "Move the CUDA sync caused by the max operation to a higher level" on Dec 29, 2025
@clessig (Collaborator) commented Dec 29, 2025

What is the performance impact of the synchronization, i.e. how much faster is the code with the max operation moved up?

@javak87 (Contributor, Author) commented Dec 30, 2025

What is the performance impact of the synchronization, i.e. how much faster is the code with the max operation moved up?

The problem isn’t an explicit CUDA synchronization. I’m trying to compile the model incrementally, starting from the lowest level (the attention module). At this level, to compile the class that uses flash_attn_varlen_func, max_lens must be a plain Python integer (per the official FlashAttention repository). However, putting an int conversion inside the forward method triggers this compilation error:

 File "/p/project1/hclimrep/kasravi1/WeatherGenerator/.venv/lib/python3.12/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item

from user code:
   File "/p/project1/hclimrep/kasravi1/WeatherGenerator/src/weathergen/model/attention.py", line 89, in forward
    max_lens = int(x_lens.max())

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Calling it without the int conversion also triggers an error:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function flash_attn._flash_attn_varlen_forward(*(FakeTensor(..., device='cuda:2', size=(3715, 16, 64), dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), FakeTensor(..., device='cuda:2', size=(3715, 16, 64), dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>), FakeTensor(..., device='cuda:2', size=(3715, 16, 64), dtype=torch.bfloat16, grad_fn=<ViewBackward0>), FakeTensor(..., device='cuda:2', size=(6145,), dtype=torch.int32), FakeTensor(..., device='cuda:2', size=(6145,), dtype=torch.int32), FakeTensor(..., device='cuda:2', size=(), dtype=torch.int64), FakeTensor(..., device='cuda:2', size=(), dtype=torch.int64), 0.1, 0.125), **{'causal': False, 'window_size_left': -1, 'window_size_right': -1, 'softcap': 0.0, 'alibi_slopes': None, 'return_softmax': False, 'block_table': None}):
flash_attn::_flash_attn_varlen_forward() Expected a value of type 'int' for argument 'max_seqlen_q' but instead found type 'FakeTensor'.
Position: 5
Value: FakeTensor(..., device='cuda:2', size=(), dtype=torch.int64)
Declaration: flash_attn::_flash_attn_varlen_forward(Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left=-1, SymInt window_size_right=-1, float softcap=0., Tensor? alibi_slopes=None, bool return_softmax=False, Tensor? block_table=None, Tensor? leftpad_k=None, Tensor? seqused_k=None, bool zero_tensors=False) -> (Tensor, Tensor, Tensor, Tensor)
Cast error details: Unable to cast Python instance of type <class 'torch._subclasses.fake_tensor.FakeTensor'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)

Therefore, we should avoid computing x_lens.max()—or any other operation that takes a maximum—inside the forward method.
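
A minimal standalone repro of the Dynamo limitation, independent of the WeatherGenerator code (the function and tensor names are made up):

    import torch

    @torch.compile(fullgraph=True)
    def compiled_max(x_lens: torch.Tensor) -> int:
        # int(...) calls Tensor.item() under the hood, which Dynamo cannot trace
        # with default settings -> torch._dynamo.exc.Unsupported: Tensor.item
        return int(x_lens.max())

    x_lens = torch.tensor([3, 7, 5], device="cuda")
    compiled_max(x_lens)  # raises the Unsupported error quoted above

There is torch._dynamo.config.capture_scalar_outputs as an escape hatch, but it captures the value as a SymInt rather than a plain Python int.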

@tjhunter (Collaborator) commented

I can see how int(x_lens.max()) is going to cause a full sync (it moves the value to CPU). What happens if you perform a cast instead, to something such as int32? This is what the C++ side wants:

               int max_seqlen_q,
               const int max_seqlen_k,
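
For reference, a cast to int32 keeps the value on the device; converting it to the Python int that the binding ultimately expects is what triggers the sync. A small illustration (hypothetical tensors, not the actual attention code):

    import torch

    x_lens = torch.tensor([128, 512, 3715], device="cuda", dtype=torch.int32)

    max_seqlen_t = x_lens.max().to(torch.int32)  # still a 0-dim CUDA tensor, no sync yet
    max_seqlen = int(max_seqlen_t)               # implicit .item(): device-to-host copy and sync

Passing the 0-dim tensor straight through under torch.compile is what produces the FakeTensor cast error quoted above.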

@clessig (Collaborator) commented Jan 5, 2026

The problem isn’t an explicit CUDA synchronization. [...] Therefore, we should avoid computing x_lens.max() (or any other operation that takes a maximum) inside the forward method.

But flash_attn is not compilable because of the C++ code dependency (there is in principle a way to wrap flash_attn so that it is compilable, but I am unconvinced that it will work). The issue won't be solved by moving the max computation.
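
For illustration, the "wrap flash_attn" route would roughly mean registering the call as an opaque custom op with a fake (meta) implementation, so torch.compile never traces into the C++ binding. This is a forward-only sketch under the assumption of PyTorch >= 2.4 (training would additionally need a backward registered via register_autograd), and the op name and signature are hypothetical:

    import torch
    from flash_attn import flash_attn_varlen_func

    # Opaque custom op: Dynamo/AOTAutograd only see the registered fake implementation,
    # never the C++ extension itself.
    @torch.library.custom_op("weathergen::flash_varlen", mutates_args=())
    def flash_varlen(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
        return flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)

    @flash_varlen.register_fake
    def _(q, k, v, cu_seqlens, max_seqlen):
        # flash attention returns a tensor with the same shape and dtype as q
        return torch.empty_like(q)

Whether this holds up for the backward pass and for varying sequence lengths is exactly the open question.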

Linked issue: Max calculation inside of forward creates CUDA sync