modularize with src/diffuse_nnx directory #2
Open
DBraun wants to merge 2 commits into willisma:main from
Conversation
Owner
Thanks for the amazing work! The overall refactoring looks good to me! Have you tested whether this modularization has any effect on the training / evaluation code?
Author
Thanks. I haven't tested it beyond basic installation and importing. I would suggest creating a GitHub Actions workflow to automate some of this. I can add it to this branch if you'd like.
Owner
Yes, that would be great!
Force-pushed from 6dab321 to 97a0e15
Author
@willisma There's now a GitHub Action that runs the tests. However, there are some failures. It looks like Flax 0.10.7 gets installed; if you were to update to 0.12.0, there would definitely be more errors. Do you think you could take it from here, since you're more familiar with the test code?
Owner
Thanks for the amazing work! Yes, let me take over!
Thanks for this repository. I did this refactor to give it a more standard module organization. Then, after installing, other projects can do `from diffuse_nnx.networks.transformers.lightning_ddt_nnx import LightningDDT`, etc. Here is a `CLAUDE.md` in case you'd like to add it too.

CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
Overview
DiffuseNNX is a JAX/NNX library for diffusion and flow matching generative models. It implements DiT (Diffusion Transformer) and variants for ImageNet training with various sampling strategies. Built on JAX and Flax NNX (with PyTorch-like syntax).
Project Structure
The codebase uses the standard src/ layout:
Development Commands
Installation
Testing
Training
Reference training scripts are in `commands/`.

Architecture
Core Design Pattern
The codebase separates concerns into three main layers:
- Interfaces (`diffuse_nnx.interfaces`): diffusion/flow matching algorithms (SiT, EDM, MeanFlow, REPA)
- Networks (`diffuse_nnx.networks`): model architectures (DiT, LightningDiT, VAE encoders/decoders)
- Samplers (`diffuse_nnx.samplers`): sampling strategies (Euler, Heun, Euler-Maruyama)
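A rough, hypothetical sketch of this separation (class names and signatures below are illustrative, not the repo's actual API): the sampler only talks to an interface, which wraps a network.

```python
# Hypothetical sketch of the three-layer separation; names and
# signatures are illustrative, not diffuse_nnx's actual API.

class ToyNetwork:
    """Stands in for diffuse_nnx.networks: maps (x, t) to a prediction."""
    def __call__(self, x, t):
        return [t * xi for xi in x]  # dummy velocity prediction

class ToyInterface:
    """Stands in for diffuse_nnx.interfaces: wraps a network and defines
    what the prediction means (here: a flow-matching velocity)."""
    def __init__(self, network):
        self.network = network
    def velocity(self, x, t):
        return self.network(x, t)

class ToyEulerSampler:
    """Stands in for diffuse_nnx.samplers: integrates the ODE, agnostic
    to the network's internals."""
    def step(self, interface, x, t, dt):
        v = interface.velocity(x, t)
        return [xi + dt * vi for xi, vi in zip(x, v)]

sampler = ToyEulerSampler()
x1 = sampler.step(ToyInterface(ToyNetwork()), x=[1.0, 2.0], t=0.5, dt=0.1)
# each xi becomes xi + dt * (t * xi), i.e. (1 + dt * t) * xi
```

Because the sampler depends only on the interface, swapping in a different network or a different integration scheme leaves the other two layers untouched.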
Training Flow

1. `__main__.py` parses flags and creates the workdir (GCS bucket or local filesystem)
2. The config is loaded from `configs/*.py` (uses `ml_collections.ConfigDict`)
3. `get_trainers()` selects the trainer module (currently only `dit_imagenet`)
Config System (`diffuse_nnx.configs`)

Configs use `ml_collections.ConfigDict` with presets from `common_specs.py`:

- `_imagenet_data_presets`: dataset paths, image sizes, batch sizes
- `_imagenet_encoder_presets`: encoder types (RGB, StabilityVAE, etc.)
- `_dit_network_presets`: network architectures (hidden_size, depth, num_heads)

Important: update the `_*_data_presets` entries in `common_specs.py` to point to your ImageNet data paths and FID statistics before training.
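The preset pattern can be sketched with plain dicts (the real code uses `ml_collections.ConfigDict`; the preset names below mirror `common_specs.py`, but the keys and values are illustrative):

```python
# Plain-dict sketch of the preset lookup in common_specs.py.
# Preset names mirror the real file; the entries are illustrative.
_imagenet_data_presets = {
    "imagenet256": {"data_dir": "/path/to/imagenet", "image_size": 256,
                    "batch_size": 256},
}
_dit_network_presets = {
    "DiT-B": {"hidden_size": 768, "depth": 12, "num_heads": 12},
}

def get_config(data="imagenet256", network="DiT-B"):
    """Merge a data preset and a network preset into one flat config."""
    cfg = {}
    cfg.update(_imagenet_data_presets[data])
    cfg.update(_dit_network_presets[network])
    return cfg
```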
NNX Training Pattern

Training uses Flax NNX with in-place updates:

- Models follow the `nnx.GraphDef` and `nnx.State` split/merge pattern
- `optimizer.model` accesses the network
- EMA weights are kept as `ema_graph` and `ema_state`
- Each step calls `optimizer.update(grads)` and `ema.update(model)`
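The EMA rule behind the last bullet can be written out in a Flax-free sketch (the real code updates `nnx.State` pytrees; plain lists stand in for parameters here):

```python
def ema_update(ema_params, params, decay=0.999):
    """One EMA step: ema <- decay * ema + (1 - decay) * params.
    Plain-Python stand-in for the pytree update applied to nnx.State."""
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, params)]
```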
Distributed Training (`diffuse_nnx.utils.sharding_utils`)

- See `docs/utils/fsdp_in_jax_nnx.ipynb` for an FSDP tutorial

Import Conventions

All imports use the full package name with the `diffuse_nnx.` prefix.
Important Notes

Platform & Dependencies

- On GPU machines, run `pip install -e .[gpu]` instead of `[tpu]`
Environment Setup

Required environment variables (store in `.env`, never commit):

- `WANDB_API_KEY`: for logging (get from https://wandb.ai/authorize)
- `WANDB_ENTITY`: your W&B team/space
- `GOOGLE_APPLICATION_CREDENTIALS`: path to GCP credentials JSON
- `GCS_BUCKET`: Google Cloud Storage bucket name

Run `gcloud auth application-default login` before using GCS.
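A minimal stdlib sketch of loading such a `.env` file (many projects use `python-dotenv` for this instead; the parsing rules below are an assumption, not the repo's actual loader):

```python
import os

def load_env(path=".env"):
    """Read KEY=VALUE lines into os.environ, skipping blanks and
    comments. Already-set environment variables are not overwritten."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())
```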
Flax Transformers Deprecation

Some parts still depend on the deprecated Flax `transformers` library; a replacement is in progress.

Code Style (from README)
- `snake_case` for modules/functions, `CamelCase` for classes
- `absl.logging` for logs, `ml_collections.ConfigDict` for configs
- See `trainers/dit_imagenet.py` for an import-grouping example
Test Naming Convention

Test files must follow the `*_tests.py` pattern for `tests/runner.py` to discover them.
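The discovery rule can be sketched as follows (the actual logic in `tests/runner.py` is not shown in this file, so the glob below is an assumption):

```python
from pathlib import Path

def discover_test_files(test_dir="tests"):
    """Return filenames matching the *_tests.py naming convention,
    searching test_dir recursively."""
    return sorted(p.name for p in Path(test_dir).glob("**/*_tests.py"))
```

Note that under this rule a file named `test_sampler.py` would be silently skipped; only the `*_tests.py` suffix form is picked up.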