From 5200f4e48806698648a74fe143b6fd0be1e6becd Mon Sep 17 00:00:00 2001 From: Weida Hong Date: Fri, 20 Feb 2026 13:13:07 +0000 Subject: [PATCH 1/2] Apply new format for ruff tool Current formatting breaks CI. Signed-off-by: Weida Hong --- torchax/checkpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchax/checkpoint.py b/torchax/checkpoint.py index a8d9a40..6caf785 100644 --- a/torchax/checkpoint.py +++ b/torchax/checkpoint.py @@ -35,12 +35,12 @@ def to_jax_array(x): def _to_torch(pytree): - return jax.tree_util.tree_map( - lambda x: torch.from_numpy(np.asarray(x)) - if isinstance(x, (jnp.ndarray, jax.Array)) - else x, - pytree, - ) + def to_torch_tensor(x): + if isinstance(x, (jnp.ndarray, jax.Array)): + return torch.from_numpy(np.asarray(x)) + return x + + return jax.tree_util.tree_map(to_torch_tensor, pytree) def save_checkpoint(state: dict[str, Any], path: str, step: int): From a118151ac884b48b24bca9b28eb311d599f2ec38 Mon Sep 17 00:00:00 2001 From: Weida Hong Date: Sat, 21 Feb 2026 13:30:04 +0000 Subject: [PATCH 2/2] Remove unused code in op registration Signed-off-by: Weida Hong --- torchax/ops/jaten.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/torchax/ops/jaten.py b/torchax/ops/jaten.py index 3f0886e..e2d72d2 100644 --- a/torchax/ops/jaten.py +++ b/torchax/ops/jaten.py @@ -25,7 +25,6 @@ import torch.distributed._functional_collectives from jax import numpy as jnp -from torchax import interop from torchax.ops import jax_reimplement, mappings, op_base, ops_registry from torchax.view import View @@ -38,22 +37,6 @@ def op(*aten, **kwargs): def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) - continue - - if isinstance(a, torch._ops.OpOverloadPacket): - opname = ( - a.default.name() if "default" in a.overloads() else a._qualified_op_name - ) - elif isinstance(a, torch._ops.OpOverload): - opname = a.name() - else: - raise RuntimeError(f"oops {a}") - - torchfunc = functools.partial(interop.call_jax, func) - # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor - torch.library.impl(opname, "privateuseone")( - torchfunc if a != torch.ops.aten._to_copy else func - ) return func return inner