diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py
old mode 100644
new mode 100755
index 78ad000b..c18a0589
--- a/src/maxdiffusion/checkpointing/flux_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -20,6 +20,7 @@
 import json
 import jax
 from jax.sharding import Mesh
+from flax.traverse_util import flatten_dict, unflatten_dict
 import orbax.checkpoint as ocp
 import grain.python as grain
 from maxdiffusion import (
@@ -103,8 +104,13 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
         training=is_training,
     )
     if not self.config.train_new_flux:
-      flux_state = flux_state.replace(params=transformer_params)
-      flux_state = jax.device_put(flux_state, state_mesh_shardings)
+      with self.mesh:
+        flat_state_shardings = flatten_dict(state_mesh_shardings.params)
+        param_state = flatten_dict(flux_state.params)
+        for path, val in flatten_dict(transformer_params).items():
+          sharding = flat_state_shardings[path]
+          param_state[path].value = max_utils.device_put_replicated(val, sharding)
+      flux_state = flux_state.replace(params=unflatten_dict(param_state))
     return flux_state, state_mesh_shardings, learning_rate_scheduler
 
   def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):