diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py index c18a0589..4c1aa6ae 100755 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -109,8 +109,8 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training) param_state = flatten_dict(flux_state.params) for path, val in flatten_dict(transformer_params).items(): sharding = flat_state_shardings[path] - param_state[path].value = max_utils.device_put_replicated(val, sharding) - flux_state = flux_state.replace(params=unflatten_dict(param_state)) + param_state[path] = max_utils.device_put_replicated(val, sharding) + flux_state = flux_state.replace(params=unflatten_dict(param_state)) return flux_state, state_mesh_shardings, learning_rate_scheduler def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):