From 16bedadd5a6a6230160b07c0c5b9a6009b489294 Mon Sep 17 00:00:00 2001
From: Carl Persson
Date: Thu, 5 Feb 2026 15:31:08 +0000
Subject: [PATCH 1/2] enable flux multinode training

---
 src/maxdiffusion/checkpointing/flux_checkpointer.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)
 mode change 100644 => 100755 src/maxdiffusion/checkpointing/flux_checkpointer.py

diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py
old mode 100644
new mode 100755
index 78ad000b..c18a0589
--- a/src/maxdiffusion/checkpointing/flux_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -20,6 +20,7 @@
 import json
 import jax
 from jax.sharding import Mesh
+from flax.traverse_util import flatten_dict, unflatten_dict
 import orbax.checkpoint as ocp
 import grain.python as grain
 from maxdiffusion import (
@@ -103,8 +104,13 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
       training=is_training,
     )
     if not self.config.train_new_flux:
-      flux_state = flux_state.replace(params=transformer_params)
-      flux_state = jax.device_put(flux_state, state_mesh_shardings)
+      with self.mesh:
+        flat_state_shardings = flatten_dict(state_mesh_shardings.params)
+        param_state = flatten_dict(flux_state.params)
+        for path, val in flatten_dict(transformer_params).items():
+          sharding = flat_state_shardings[path]
+          param_state[path].value = max_utils.device_put_replicated(val, sharding)
+          flux_state = flux_state.replace(params=unflatten_dict(param_state))
     return flux_state, state_mesh_shardings, learning_rate_scheduler
 
   def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):

From ca03dd445edffa588b1a27e8645733a195500776 Mon Sep 17 00:00:00 2001
From: Carl Persson
Date: Fri, 6 Feb 2026 11:09:42 +0000
Subject: [PATCH 2/2] fix parameter loading

---
 src/maxdiffusion/checkpointing/flux_checkpointer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py
index c18a0589..4c1aa6ae 100755
--- a/src/maxdiffusion/checkpointing/flux_checkpointer.py
+++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py
@@ -109,8 +109,8 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
         param_state = flatten_dict(flux_state.params)
         for path, val in flatten_dict(transformer_params).items():
           sharding = flat_state_shardings[path]
-          param_state[path].value = max_utils.device_put_replicated(val, sharding)
-          flux_state = flux_state.replace(params=unflatten_dict(param_state))
+          param_state[path] = max_utils.device_put_replicated(val, sharding)
+        flux_state = flux_state.replace(params=unflatten_dict(param_state))
     return flux_state, state_mesh_shardings, learning_rate_scheduler
 
   def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
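
Note: for reference, the final loop in create_flux_state places each pretrained checkpoint leaf onto the sharding registered for the same flattened path, then rebuilds the nested params dict. Below is a minimal standalone sketch of that pattern, assuming a one-axis mesh; plain jax.device_put stands in for max_utils.device_put_replicated (whose exact signature this diff does not show), and the shapes and shardings table are hypothetical.

# Sketch of per-parameter sharded placement, as in the loop above.
# Assumes the device count divides the kernel's leading dimension.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax.traverse_util import flatten_dict, unflatten_dict

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("data",))

# Host-loaded checkpoint params (hypothetical shapes).
transformer_params = {"proj": {"kernel": np.ones((8, 4)), "bias": np.zeros((4,))}}

# Target shardings keyed by the same flattened paths as the params,
# playing the role of flat_state_shardings in the patch.
shardings = {
    ("proj", "kernel"): NamedSharding(mesh, P("data", None)),
    ("proj", "bias"): NamedSharding(mesh, P(None)),
}

# Place each leaf on its own sharding, then rebuild the nested dict.
flat = {}
for path, val in flatten_dict(transformer_params).items():
    flat[path] = jax.device_put(val, shardings[path])
params = unflatten_dict(flat)
print(jax.tree_util.tree_map(lambda x: x.sharding, params))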