Changes from all commits (19 commits)
f94ee08  Update VaceWanPipeline to VaceWanPipeline2_1, fix from_pretrained (#311) [ninatu, Jan 20, 2026]
bf2d5b2  Update the linter to run on a valid github-hosted-runner (#312) [michelle-yooh, Jan 20, 2026]
6734125  Fix formatting with pyink (#314) [michelle-yooh, Jan 20, 2026]
042932e  Add support for TransformerEngine flash attention in WAN (#299) [cpersson-amd, Jan 21, 2026]
f9b6ff9  Shard video_condition to prevent OOM in WAN 2.2 I2V (#313) [prishajain1, Jan 21, 2026]
5a05e75  Fixes mismatched number of output arguments (#315) [martinarroyo, Jan 21, 2026]
a385c8e  fp8 bug for batch_size setting error (#317) [susanbao, Jan 22, 2026]
5b0b8cf  Add batch divisibility check for VAE input sharding (#316) [prishajain1, Jan 23, 2026]
ad56886  Add GitHub Action to add "pull ready" to a PR upon successful checks … [michelle-yooh, Jan 24, 2026]
63b5ed7  Enable JIT Compilation of WAN VAE Encoder/Decoder Forward Passes (#320) [prishajain1, Jan 29, 2026]
9622341  Add LoRA support for WAN models (#308) [Perseus14, Jan 29, 2026]
1e1058a  Move the top level 'tests/' into `src/maxdiffusion/tests/` as a legac… [michelle-yooh, Jan 29, 2026]
54a7518  feat: add general synthetic data iterator and examples for WAN and FLUX. [meijianhan, Jan 20, 2026]
206f7db  doc: update README.md with synthetic data iterator usage. [meijianhan, Jan 20, 2026]
4793049  feat: add loss_scaling_factor for matching the flux loss. [jianhan-amd, Jan 29, 2026]
26aeeab  feat: add multi-node running scripts. Add dockerfile with IB driver i… [jianhan-amd, Jan 29, 2026]
0bdcd72  feat: add multi-node scripts without using Slurm. [Jan 30, 2026]
44a148b  fix: fix the WAN1.3B model loading bug [Jan 31, 2026]
16bedad  enable flux multinode training [cpersson-amd, Feb 5, 2026]
1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -0,0 +1 @@
* @entrpn
116 changes: 116 additions & 0 deletions .github/workflows/AddPullReady.yml
@@ -0,0 +1,116 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: Add Pull Ready Label

on:
workflow_run:
workflows: [Unit Test, Linter]
types:
- completed
pull_request_review:
pull_request_review_comment:
workflow_dispatch:

jobs:
AddPullReady:
permissions:
checks: read
pull-requests: write
runs-on: ubuntu-latest

steps:
- uses: actions/github-script@v7
with:
script: |
const owner = context.repo.owner
const repo = context.repo.repo
let pull_number = -1
if (context.payload.pull_request !== undefined) {
pull_number = context.payload.pull_request.number
} else if (context.payload.workflow_run !== undefined) {
if (context.payload.workflow_run.pull_requests.length === 0) {
console.log("This workflow is NOT running within a PR's context")
process.exit()
}
console.log(context.payload.workflow_run.pull_requests)
pull_number = context.payload.workflow_run.pull_requests[0].number
} else {
console.log("This workflow is running within an invalid context")
process.exit(1)
}
const reviews = await github.rest.pulls.listReviews({
owner,
repo,
pull_number,
})
const decision_query = `
query($owner: String!, $repo: String!, $pull_number: Int!) {
repository(owner: $owner, name: $repo) {
pullRequest(number: $pull_number) {
reviewDecision # Fetches the overall review status
}
}
}
`;
const decision_result = await github.graphql(decision_query, { owner, repo, pull_number });
if (reviews.data.length === 0) {
console.log("Not adding pull ready because the PR is not approved yet.")
process.exit()
}
let is_approved = false
if (decision_result.repository.pullRequest.reviewDecision === "APPROVED") {
is_approved = true
}
if (!is_approved) {
console.log("Not adding pull ready because the PR is not approved yet by sufficient code owners.")
process.exit()
}
const commits = await github.rest.pulls.listCommits({
owner,
repo,
pull_number,
per_page: 100,
})
// Check that the number of commits in the PR is 1.
if (commits.data.length !== 1) {
console.log("Not adding pull ready because the PR has more than one commit. Please squash your commits.")
process.exit(1)
}
const ref = commits.data.slice(-1)[0].sha
const checkRuns = await github.rest.checks.listForRef({
owner,
repo,
ref,
})
if (checkRuns.data.check_runs.length === 0) {
console.log("Not adding pull ready because no check runs are associated with the last commit: " + ref)
process.exit()
}
for (const checkRun of checkRuns.data.check_runs) {
if (checkRun.name.endsWith(context.job)) continue
if (checkRun.conclusion !== "success") {
console.log("Not adding pull ready because " + checkRun.name + " has not passed yet: " + checkRun.html_url)
process.exit()
}
}
console.log("Adding pull ready label because the PR is approved AND all the check runs have passed")
await github.rest.issues.addLabels({
issue_number: pull_number,
labels: ["pull ready"],
owner,
repo,
})
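The gating logic above adds the "pull ready" label only when the PR is approved, has exactly one commit, and every check run on that commit has succeeded. A minimal sketch of checking the same conditions locally with the GitHub CLI (assumes gh is installed and authenticated; <PR_NUMBER> is a placeholder):

  # Overall review decision; the workflow requires "APPROVED"
  gh pr view <PR_NUMBER> --json reviewDecision
  # Commit count; the workflow requires exactly one commit
  gh pr view <PR_NUMBER> --json commits --jq '.commits | length'
  # Check-run results on the head commit
  gh pr view <PR_NUMBER> --json statusCheckRollup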
6 changes: 3 additions & 3 deletions .github/workflows/CPUTests.yml
@@ -11,8 +11,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-20.04]
python-version: ['3.10']
os: [ubuntu-latest]
python-version: ['3.12']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -22,7 +22,7 @@
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install pylint pyink pytype==2024.2.27
pip install pylint pyink==23.10.0 pytype==2024.2.27
# - name: Typecheck the code with pytype
# run: |
# pytype --jobs auto --disable import-error src/maxdiffusion/
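A minimal sketch of reproducing the lint job locally (the pinned versions follow the workflow above; using code_style.sh as the formatting entry point is an assumption based on the script name in this repository):

  python -m pip install --upgrade pip
  pip install pylint pyink==23.10.0 pytype==2024.2.27
  bash code_style.sh   # assumed to apply the repository's pyink/pylint formatting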
14 changes: 14 additions & 0 deletions README.md
100644 → 100755
@@ -255,6 +255,20 @@ After installation completes, run the training script.
- In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, and the ici_tensor_parallelism axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads, so 40 must be evenly divisible by ici_tensor_parallelism.
- For sequence parallelism, the code pads the sequence so its length divides evenly across devices. Try different ici_fsdp_parallelism values; we currently find 2 and 4 to work best.
- On GPU, enabling the cudnn_flash_te attention kernel is recommended for optimal performance.
- Best performance is achieved with batch parallelism, enabled via the ici_fsdp_batch_parallelism axis. Note that this strategy does not support fractional batch sizes.
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow fractional batch sizes. However, padding is not currently supported for the cudnn_flash_te attention kernel, so the sequence length must be divisible by the number of devices on the ici_fsdp_parallelism axis.
- To benchmark training performance across different data dimensions without downloading or re-processing the dataset, a synthetic data iterator is supported (an example command follows this list).
- Set dataset_type='synthetic' and synthetic_num_samples=null to enable the synthetic data iterator.
- The following overrides on data dimensions are supported:
- synthetic_override_height: 720
- synthetic_override_width: 1280
- synthetic_override_num_frames: 85
- synthetic_override_max_sequence_length: 512
- synthetic_override_text_embed_dim: 4096
- synthetic_override_num_channels_latents: 16
- synthetic_override_vae_scale_factor_spatial: 8
- synthetic_override_vae_scale_factor_temporal: 4
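
A minimal sketch of a synthetic benchmark run (flag spellings follow the overrides listed above and the launch.sh added in this PR; the config file base_wan_14b.yml, run_name, and the specific values are illustrative, not prescriptive):

  python -m src.maxdiffusion.train_wan src/maxdiffusion/configs/base_wan_14b.yml \
    run_name=synthetic_benchmark output_dir="$PWD/output" \
    hardware=gpu attention=cudnn_flash_te \
    dataset_type=synthetic synthetic_num_samples=null \
    synthetic_override_height=720 synthetic_override_width=1280 \
    synthetic_override_num_frames=85 synthetic_override_max_sequence_length=512 \
    ici_fsdp_parallelism=8 per_device_batch_size=1 \
    max_train_steps=10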

You should eventually see a training run as:

Empty file modified code_style.sh
100644 → 100755
Empty file.
22 changes: 11 additions & 11 deletions end_to_end/tpu/eval_assert.py
@@ -1,18 +1,18 @@
"""
Copyright 2024 Google LLC
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0
https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Example to run
122 changes: 122 additions & 0 deletions launch.sh
@@ -0,0 +1,122 @@
#!/usr/bin/env bash




# Parse LOG_PATH argument from command line
LOG_PATH=""
FILTERED_ARGS=()
for arg in "$@"; do
if [[ $arg == LOG_PATH=* ]]; then
LOG_PATH="${arg#*=}"
else
FILTERED_ARGS+=("$arg")
fi
done

# Experiment name, used for the default log path below and for run_name later in the script
EXP_NAME="WAN_train"

# Set default log directory if not provided
if [ -z "$LOG_PATH" ]; then
LOG_PATH="$PWD/output/output_$EXP_NAME"
fi

export HF_TOKEN=""
export HF_HOME="/app/hf_home/"

# export ROCR_VISIBLE_DEVICES="4,5,6,7"

export MIOPEN_CUSTOM_CACHE_DIR="/app/.cache/miopen/"
export JAX_COMPILATION_CACHE_DIR="/app/.cache/jax/"
export JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES="all"

# export TF_CPP_MIN_LOG_LEVEL=0
# export TF_CPP_MAX_VLOG_LEVEL=3
export JAX_TRACEBACK_FILTERING=off

timestamp=$(date +%Y%m%d-%H%M%S)

export LIBTPU_INIT_ARGS=""

export KERAS_BACKEND="jax"
export JAX_SPMD_MODE="allow_all"
export TOKENIZERS_PARALLELISM="1"

# to skip hard-coded GCS calls
export SKIP_GCS=1

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95
export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=16384

export NVTE_FUSED_ATTN=1
export NVTE_CK_USES_BWD_V3=1 # activates v3, 0 default
export NVTE_CK_USES_FWD_V3=1 # 0 for fsdp tpu
export NVTE_CK_IS_V3_ATOMIC_FP32=0 # default
export NVTE_CK_HOW_V3_BF16_CVT=1 # default
export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1

export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7
export NCCL_SOCKET_IFNAME=ens51f1np1
export NCCL_IB_GID_INDEX=3
export NCCL_PROTO=Simple

export HSA_FORCE_FINE_GRAIN_PCIE=1
export NCCL_MAX_NCHANNELS=16
export RCCL_MSCCL_ENABLE=0
export GPU_MAX_HW_QUEUES=2
export HIP_FORCE_DEV_KERNARG=1
export HSA_NO_SCRATCH_RECLAIM=1
# NCCL flags
export NCCL_DEBUG=INFO #WARN, INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export RCCL_REPLAY_FILE=/shared_nfs/jianhan/slurm_logs-${SCALING_EXP}/cohere-${SLURM_JOB_NUM_NODES}N-8x22B-${SLURM_JOB_ID}-${timestamp}/mixtral_8x-22b_128N_run.bin
export NCCL_PROTO=Simple
export NCCL_IB_TIMEOUT=20
export NCCL_IB_TC=41
export NCCL_IB_SL=0

export GLOO_SOCKET_IFNAME=ens51f1np1
export NCCL_CROSS_NIC=0
export NCCL_CHECKS_DISABLE=1
export NCCL_IB_QPS_PER_CONNECTION=1
# OCI reports that the env var below can improve all-to-all communication:
export NCCL_PXN_DISABLE=0
# export NCCL_P2P_NET_CHUNKSIZE=524288
# export NCCL_MAX_NCHANNELS=16


# UCX flags
export UCX_TLS=tcp,self,sm
export UCX_IB_TRAFFIC_CLASS=41
export UCX_IB_SL=0


HOST_NAME=$(hostname)

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_cublaslt=True
--xla_gpu_graph_level=0 --xla_gpu_autotune_level=5 --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728
--xla_dump_to=${LOG_PATH}/${HOST_NAME}_xla_dump_${timestamp}"

rm -rf /app/.cache/*
python3 setup.py develop

# Make sure the log directory exists before tee appends to it
mkdir -p "$LOG_PATH"
LOG_FILE="$LOG_PATH/output_$HOST_NAME.log"

# python -m src.maxdiffusion.train_flux src/maxdiffusion/configs/base_flux_dev.yml \
python -m src.maxdiffusion.train_wan src/maxdiffusion/configs/base_wan_14b.yml \
run_name="run_$EXP_NAME" output_dir="$PWD/output" \
hardware=gpu \
attention=cudnn_flash_te \
max_train_steps=10 \
dcn_data_parallelism=-1 \
dcn_fsdp_batch_parallelism=1 \
ici_data_parallelism=1 \
ici_fsdp_parallelism=8 \
per_device_batch_size=1 \
enable_ssim=False \
"${FILTERED_ARGS[@]}" |& tee -a "$LOG_FILE"


