From c4f74e9111b78d3e28a061c5bea8494bd29315f1 Mon Sep 17 00:00:00 2001 From: Jacob Platin Date: Sun, 22 Feb 2026 06:22:17 +0000 Subject: [PATCH] Initial commit --- torchax/ops/jaten.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchax/ops/jaten.py b/torchax/ops/jaten.py index 3f0886e..a5a42e6 100644 --- a/torchax/ops/jaten.py +++ b/torchax/ops/jaten.py @@ -542,7 +542,7 @@ def _aten_dist(input, other, p=2): @op(torch.ops.aten.bmm) -def _aten_bmm(x, y): +def _aten_bmm(x, y, out=None): res = x @ y return res # return jnp.einsum('bnm,bmk->bnk', x, y)