Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions torchax/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def to_jax_array(x):


def _to_torch(pytree):
return jax.tree_util.tree_map(
lambda x: torch.from_numpy(np.asarray(x))
if isinstance(x, (jnp.ndarray, jax.Array))
else x,
pytree,
)
def to_torch_tensor(x):
if isinstance(x, (jnp.ndarray, jax.Array)):
return torch.from_numpy(np.asarray(x))
return x

return jax.tree_util.tree_map(to_torch_tensor, pytree)


def save_checkpoint(state: dict[str, Any], path: str, step: int):
Expand Down
17 changes: 0 additions & 17 deletions torchax/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torch.distributed._functional_collectives
from jax import numpy as jnp

from torchax import interop
from torchax.ops import jax_reimplement, mappings, op_base, ops_registry
from torchax.view import View

Expand All @@ -38,22 +37,6 @@ def op(*aten, **kwargs):
def inner(func):
for a in aten:
ops_registry.register_torch_dispatch_op(a, func, **kwargs)
continue

if isinstance(a, torch._ops.OpOverloadPacket):
opname = (
a.default.name() if "default" in a.overloads() else a._qualified_op_name
)
elif isinstance(a, torch._ops.OpOverload):
opname = a.name()
else:
raise RuntimeError(f"oops {a}")

torchfunc = functools.partial(interop.call_jax, func)
# HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor
torch.library.impl(opname, "privateuseone")(
torchfunc if a != torch.ops.aten._to_copy else func
)
return func

return inner
Expand Down