Skip to content

Add additional argument to _aten_bmm#79

Open
jrplatin wants to merge 1 commit intogoogle:mainfrom
jrplatin:jacobplatin/mla
Open

Add additional argument to _aten_bmm#79
jrplatin wants to merge 1 commit intogoogle:mainfrom
jrplatin:jacobplatin/mla

Conversation

@jrplatin
Copy link

Fixes this error from running MLA + TorchAX:

(EngineCore_DP0 pid=613611)   File "/mnt/disks/jacobplatin/vllm/vllm/model_executor/layers/attention/mla_attention.py", line 440, in forward
(EngineCore_DP0 pid=613611)     self.forward_impl(
(EngineCore_DP0 pid=613611)   File "/mnt/disks/jacobplatin/vllm/vllm/model_executor/layers/attention/mla_attention.py", line 614, in forward_impl
(EngineCore_DP0 pid=613611)     torch.bmm(mqa_q_nope, self.W_UK_T, out=mqa_ql_nope)
(EngineCore_DP0 pid=613611)   File "/mnt/disks/jacobplatin/torchax/torchax/tensor.py", line 254, in __torch_function__
(EngineCore_DP0 pid=613611)     return func(*args, **(kwargs or {}))
(EngineCore_DP0 pid=613611)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=613611)   File "/mnt/disks/jacobplatin/torchax/torchax/tensor.py", line 277, in __torch_dispatch__
(EngineCore_DP0 pid=613611)     return self.env.dispatch(func, types, args, kwargs)
(EngineCore_DP0 pid=613611)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=613611)   File "/mnt/disks/jacobplatin/torchax/torchax/tensor.py", line 599, in dispatch
(EngineCore_DP0 pid=613611)     res = op.func(*args, **kwargs)
(EngineCore_DP0 pid=613611)           ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=613611) TypeError: _aten_bmm() got an unexpected keyword argument 'out'


@op(torch.ops.aten.bmm)
def _aten_bmm(x, y):
def _aten_bmm(x, y, out=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add an assertion to ensure the out is indeed None?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, there is a helper class OutVariant (and a mapping _out_variant_to_functional), not sure if it can solve your problem.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants