diff --git a/torchax/checkpoint.py b/torchax/checkpoint.py
index a8d9a40..6caf785 100644
--- a/torchax/checkpoint.py
+++ b/torchax/checkpoint.py
@@ -35,12 +35,12 @@ def to_jax_array(x):
 
 
 def _to_torch(pytree):
-  return jax.tree_util.tree_map(
-      lambda x: torch.from_numpy(np.asarray(x))
-      if isinstance(x, (jnp.ndarray, jax.Array))
-      else x,
-      pytree,
-  )
+  def to_torch_tensor(x):
+    if isinstance(x, (jnp.ndarray, jax.Array)):
+      return torch.from_numpy(np.asarray(x))
+    return x
+
+  return jax.tree_util.tree_map(to_torch_tensor, pytree)
 
 
 def save_checkpoint(state: dict[str, Any], path: str, step: int):
diff --git a/torchax/ops/jaten.py b/torchax/ops/jaten.py
index 3f0886e..e2d72d2 100644
--- a/torchax/ops/jaten.py
+++ b/torchax/ops/jaten.py
@@ -25,7 +25,6 @@ import torch.distributed._functional_collectives
 from jax import numpy as jnp
 
-from torchax import interop
 from torchax.ops import jax_reimplement, mappings, op_base, ops_registry
 from torchax.view import View
 
@@ -38,22 +37,6 @@ def op(*aten, **kwargs):
   def inner(func):
     for a in aten:
       ops_registry.register_torch_dispatch_op(a, func, **kwargs)
-      continue
-
-      if isinstance(a, torch._ops.OpOverloadPacket):
-        opname = (
-            a.default.name() if "default" in a.overloads() else a._qualified_op_name
-        )
-      elif isinstance(a, torch._ops.OpOverload):
-        opname = a.name()
-      else:
-        raise RuntimeError(f"oops {a}")
-
-      torchfunc = functools.partial(interop.call_jax, func)
-      # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
-      torch.library.impl(opname, "privateuseone")(
-          torchfunc if a != torch.ops.aten._to_copy else func
-      )
     return func
 
   return inner