diff --git a/torchax/checkpoint.py b/torchax/checkpoint.py index a8d9a40..6caf785 100644 --- a/torchax/checkpoint.py +++ b/torchax/checkpoint.py @@ -35,12 +35,12 @@ def to_jax_array(x): def _to_torch(pytree): - return jax.tree_util.tree_map( - lambda x: torch.from_numpy(np.asarray(x)) - if isinstance(x, (jnp.ndarray, jax.Array)) - else x, - pytree, - ) + def to_torch_tensor(x): + if isinstance(x, (jnp.ndarray, jax.Array)): + return torch.from_numpy(np.asarray(x)) + return x + + return jax.tree_util.tree_map(to_torch_tensor, pytree) def save_checkpoint(state: dict[str, Any], path: str, step: int):