From 5200f4e48806698648a74fe143b6fd0be1e6becd Mon Sep 17 00:00:00 2001 From: Weida Hong Date: Fri, 20 Feb 2026 13:13:07 +0000 Subject: [PATCH] Apply new format for ruff tool Current formatting breaks CI. Signed-off-by: Weida Hong --- torchax/checkpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchax/checkpoint.py b/torchax/checkpoint.py index a8d9a40..6caf785 100644 --- a/torchax/checkpoint.py +++ b/torchax/checkpoint.py @@ -35,12 +35,12 @@ def to_jax_array(x): def _to_torch(pytree): - return jax.tree_util.tree_map( - lambda x: torch.from_numpy(np.asarray(x)) - if isinstance(x, (jnp.ndarray, jax.Array)) - else x, - pytree, - ) + def to_torch_tensor(x): + if isinstance(x, (jnp.ndarray, jax.Array)): + return torch.from_numpy(np.asarray(x)) + return x + + return jax.tree_util.tree_map(to_torch_tensor, pytree) def save_checkpoint(state: dict[str, Any], path: str, step: int):