diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..9cbc7119 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @entrpn diff --git a/.github/workflows/AddPullReady.yml b/.github/workflows/AddPullReady.yml new file mode 100644 index 00000000..6122e1ed --- /dev/null +++ b/.github/workflows/AddPullReady.yml @@ -0,0 +1,116 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Add Pull Ready Label + +on: + workflow_run: + workflows: [Unit Test, Linter] + types: + - completed + pull_request_review: + pull_request_review_comment: + workflow_dispatch: + +jobs: + AddPullReady: + permissions: + checks: read + pull-requests: write + runs-on: ubuntu-latest + + steps: + - uses: actions/github-script@v7 + with: + script: | + const owner = context.repo.owner + const repo = context.repo.repo + let pull_number = -1 + if (context.payload.pull_request !== undefined) { + pull_number = context.payload.pull_request.number + } else if (context.payload.workflow_run !== undefined) { + if (context.payload.workflow_run.pull_requests.length === 0) { + console.log("This workflow is NOT running within a PR's context") + process.exit() + } + console.log(context.payload.workflow_run.pull_requests) + pull_number = context.payload.workflow_run.pull_requests[0].number + } else { + console.log("This workflow is running within an invalid context") + process.exit(1) + } + const reviews = await github.rest.pulls.listReviews({ + owner, + repo, + pull_number, + }) + const decision_query = ` + query($owner: String!, $repo: String!, $pull_number: Int!) { + repository(owner: $owner, name: $repo) { + pullRequest(number: $pull_number) { + reviewDecision # Fetches the overall review status + } + } + } + `; + const decision_result = await github.graphql(decision_query, { owner, repo, pull_number }); + + if (reviews.data.length === 0) { + console.log("Not adding pull ready because the PR is not approved yet.") + process.exit() + } + let is_approved = false + if (decision_result.repository.pullRequest.reviewDecision === "APPROVED") { + is_approved = true + } + if (!is_approved) { + console.log("Not adding pull ready because the PR is not approved yet by sufficient code owners.") + process.exit() + } + + const commits = await github.rest.pulls.listCommits({ + owner, + repo, + pull_number, + per_page: 100, + }) + // Check that the number of commits in the PR is 1. + if (commits.data.length !== 1) { + console.log("Not adding pull ready because the PR has more than one commit. Please squash your commits.") + process.exit(1) + } + const ref = commits.data.slice(-1)[0].sha + const checkRuns = await github.rest.checks.listForRef({ + owner, + repo, + ref, + }) + if (checkRuns.data.check_runs.length === 0) { + console.log("Not adding pull ready because no check runs are associated with the last commit: " + ref) + process.exit() + } + for (const checkRun of checkRuns.data.check_runs) { + if (checkRun.name.endsWith(context.job)) continue + if (checkRun.conclusion !== "success") { + console.log("Not adding pull ready because " + checkRun.name + " has not passed yet: " + checkRun.html_url) + process.exit() + } + } + console.log("Adding pull ready label because the PR is approved AND all the check runs have passed") + await github.rest.issues.addLabels({ + issue_number: pull_number, + labels: ["pull ready"], + owner, + repo, + }) diff --git a/.github/workflows/CPUTests.yml b/.github/workflows/CPUTests.yml index aa2ecfed..df0bccf2 100644 --- a/.github/workflows/CPUTests.yml +++ b/.github/workflows/CPUTests.yml @@ -11,8 +11,8 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-20.04] - python-version: ['3.10'] + os: [ubuntu-latest] + python-version: ['3.12'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -22,7 +22,7 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - pip install pylint pyink pytype==2024.2.27 + pip install pylint pyink==23.10.0 pytype==2024.2.27 # - name: Typecheck the code with pytype # run: | # pytype --jobs auto --disable import-error src/maxdiffusion/ diff --git a/README.md b/README.md old mode 100644 new mode 100755 index e6584d23..859f6c22 --- a/README.md +++ b/README.md @@ -255,6 +255,20 @@ After installation completes, run the training script. - In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism, the ici_tensor_parallelism axis is used for head parallelism. - You can enable both, keeping in mind that Wan2.1 has 40 heads and 40 must be evenly divisible by ici_tensor_parallelism. - For Sequence parallelism, the code pads the sequence length to evenly divide the sequence. Try out different ici_fsdp_parallelism numbers, but we find 2 and 4 to be the best right now. + - For use on GPU it is recommended to enable the cudnn_te_flash attention kernel for optimal performance. + - Best performance is achieved with the use of batch parallelism, which can be enabled by using the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes. + - ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow for fractional batch sizes. However, padding is not currently supported for the cudnn_te_flash attention kernel and it is therefore required that the sequence length is divisible by the number of devices in the ici_fsdp_parallelism axis. + - For benchmarking training performance on multiple data dimension input without downloading/re-processing the dataset, the synthetic data iterator is supported. + - Set dataset_type='synthetic' and synthetic_num_samples=null to enable the synthetic data iterator. + - The following overrides on data dimensions are supported: + - synthetic_override_height: 720 + - synthetic_override_width: 1280 + - synthetic_override_num_frames: 85 + - synthetic_override_max_sequence_length: 512 + - synthetic_override_text_embed_dim: 4096 + - synthetic_override_num_channels_latents: 16 + - synthetic_override_vae_scale_factor_spatial: 8 + - synthetic_override_vae_scale_factor_temporal: 4 You should eventually see a training run as: diff --git a/code_style.sh b/code_style.sh old mode 100644 new mode 100755 diff --git a/end_to_end/tpu/eval_assert.py b/end_to_end/tpu/eval_assert.py index 20fd0d8a..33f4c0b2 100644 --- a/end_to_end/tpu/eval_assert.py +++ b/end_to_end/tpu/eval_assert.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Example to run diff --git a/launch.sh b/launch.sh new file mode 100755 index 00000000..672b905d --- /dev/null +++ b/launch.sh @@ -0,0 +1,122 @@ +#!/usr/bin/env bash + + + + +# Parse LOG_PATH argument from command line +LOG_PATH="" +FILTERED_ARGS=() +for arg in "$@"; do + if [[ $arg == LOG_PATH=* ]]; then + LOG_PATH="${arg#*=}" + else + FILTERED_ARGS+=("$arg") + fi +done + +# Set default log file if not provided +if [ -z "$LOG_PATH" ]; then + LOG_PATH="$PWD/output/output_$EXP_NAME.log" +fi + +export HF_TOKEN="" +export HF_HOME="/app/hf_home/" + +# export ROCR_VISIBLE_DEVICES="4,5,6,7" + +export MIOPEN_CUSTOM_CACHE_DIR="/app/.cache/miopen/" +export JAX_COMPILATION_CACHE_DIR="/app/.cache/jax/" +export JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES="all" + +# export TF_CPP_MIN_LOG_LEVEL=0 +# export TF_CPP_MAX_VLOG_LEVEL=3 +export JAX_TRACEBACK_FILTERING=off + +timestamp=$(date +%Y%m%d-%H%M%S) + +export LIBTPU_INIT_ARGS="" + +export KERAS_BACKEND="jax" +export JAX_SPMD_MODE="allow_all" +export TOKENIZERS_PARALLELISM="1" + +# to skip hard-coded GCS calls +export SKIP_GCS=1 + +export XLA_PYTHON_CLIENT_MEM_FRACTION=0.95 +export TF_CUDNN_WORKSPACE_LIMIT_IN_MB=16384 + +export NVTE_FUSED_ATTN=1 +export NVTE_CK_USES_BWD_V3=1 # activates v3, 0 default +export NVTE_CK_USES_FWD_V3=1 # 0 for fsdp tpu +export NVTE_CK_IS_V3_ATOMIC_FP32=0 # default +export NVTE_CK_HOW_V3_BF16_CVT=1 # default +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 + +export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7 +export NCCL_SOCKET_IFNAME=ens51f1np1 +export NCCL_IB_GID_INDEX=3 +export NCCL_PROTO=Simple + +export HSA_FORCE_FINE_GRAIN_PCIE=1 +export NCCL_MAX_NCHANNELS=16 +export RCCL_MSCCL_ENABLE=0 +export GPU_MAX_HW_QUEUES=2 +export HIP_FORCE_DEV_KERNARG=1 +export HSA_NO_SCRATCH_RECLAIM=1 +# NCCL flags +export NCCL_DEBUG=INFO #WARN, INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export RCCL_REPLAY_FILE=/shared_nfs/jianhan/slurm_logs-${SCALING_EXP}/cohere-${SLURM_JOB_NUM_NODES}N-8x22B-${SLURM_JOB_ID}-${timestamp}/mixtral_8x-22b_128N_run.bin +export NCCL_PROTO=Simple +export NCCL_IB_TIMEOUT=20 +export NCCL_IB_TC=41 +export NCCL_IB_SL=0 + +export GLOO_SOCKET_IFNAME=ens51f1np1 +export NCCL_CROSS_NIC=0 +export NCCL_CHECKS_DISABLE=1 +export NCCL_IB_QPS_PER_CONNECTION=1 +## +#OCI said the below env var can improve all-to-all communication: +export NCCL_PXN_DISABLE=0 +# export NCCL_P2P_NET_CHUNKSIZE=524288 +# export NCCL_MAX_NCHANNELS=16 + + +# UCX flags +export UCX_TLS=tcp,self,sm +export UCX_IB_TRAFFIC_CLASS=41 +export UCX_IB_SL=0 + + +HOST_NAME=$(hostname) + +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_cublaslt=True + --xla_gpu_graph_level=0 --xla_gpu_autotune_level=5 --xla_gpu_enable_reduce_scatter_combine_by_dim=false + --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_all_gather_combine_threshold_bytes=134217728 + --xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 + --xla_dump_to=${LOG_PATH}/${HOST_NAME}_xla_dump_${timestamp}" + +rm -rf /app/.cache/* +python3 setup.py develop + +EXP_NAME="WAN_train" +LOG_FILE="$LOG_PATH/output_$HOST_NAME.log" + +# python -m src.maxdiffusion.train_flux src/maxdiffusion/configs/base_flux_dev.yml \ +python -m src.maxdiffusion.train_wan src/maxdiffusion/configs/base_wan_14b.yml \ + run_name="run_$EXP_NAME" output_dir="$PWD/output" \ + hardware=gpu \ + attention=cudnn_flash_te \ + max_train_steps=10 \ + dcn_data_parallelism=-1 \ + dcn_fsdp_batch_parallelism=1 \ + ici_data_parallelism=1 \ + ici_fsdp_parallelism=8 \ + per_device_batch_size=1 \ + enable_ssim=False \ + "${FILTERED_ARGS[@]}" |& tee -a "$LOG_FILE" + + + diff --git a/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile b/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile new file mode 100755 index 00000000..56920d09 --- /dev/null +++ b/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile @@ -0,0 +1,108 @@ +# CONTEXT {'gpu_vendor': 'AMD', 'guest_os': 'UBUNTU'} +############################################################################### +# +# MIT License +# +# Copyright (c) Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################# + +ARG BASE_DOCKER=rocm/pyt-megatron-lm-jax-nightly-private:jax_rocm7.1_jax_0.7.1_20251215 +# ARG BASE_DOCKER=rocm/jax-training:maxtext-v25.11 +FROM $BASE_DOCKER +USER root +ENV WORKSPACE_DIR=/workspace +RUN mkdir -p $WORKSPACE_DIR +WORKDIR $WORKSPACE_DIR + +# Environment variables +ENV HIP_FORCE_DEV_KERNARG=1 +ARG MAX_JOBS_ARG=192 +ENV MAX_JOBS=${MAX_JOBS_ARG} + +# Argument to check current GPU arch +ARG MAD_SYSTEM_GPU_ARCHITECTURE +ENV HIP_ARCHITECTURES=${MAD_SYSTEM_GPU_ARCHITECTURE} +RUN echo HIP_ARCHITECTURES = ${HIP_ARCHITECTURES} + +# Install necessary system dependencies (if any, e.g., git, build-essential) +RUN apt-get update && apt-get install -y --no-install-recommends \ + git && \ + apt-get clean && rm -rf /var/lib/apt/lists/* && \ + python3 -m pip install --upgrade pip && \ + pip install "huggingface_hub[cli]" + +RUN pip install \ + scikit-image \ + torch==2.8.0 \ + torchvision==0.24.0 \ + torchcodec \ + imageio-ffmpeg \ + --break-system-packages --find-links https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0/ + +RUN pip install \ + flax==0.11.2 \ + tokamax \ + einshape \ + typeguard==2.13.3 \ + qwix==0.1.5 --no-deps + +#Download MaxDiffusion +# RUN cd ${WORKSPACE_DIR} && \ +# git clone https://github.com/AI-Hypercomputer/maxdiffusion.git && \ +# cd maxdiffusion && \ +# git reset --hard "07b4d29c4a9bbdaafa501299275dcb15b5365034" && \ +# python3 setup.py develop +# RUN cd ${WORKSPACE_DIR} && \ +# git clone https://github.com/cpersson-amd/maxdiffusion.git && \ +# cd maxdiffusion && \ +# git reset --hard "07b4d29c4a9bbdaafa501299275dcb15b5365034" && \ +# python3 setup.py develop + +# Display installed packages for verification +RUN pip list + +# libaries for IB fabric +RUN apt-get update +RUN apt-get install -y libelf-dev unzip +RUN apt-get install -y gcc make libtool autoconf librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils perftest ethtool libibverbs-dev rdma-core strace libibmad5 libibnetdisc5 ibverbs-providers libibumad-dev libibumad3 libibverbs1 libnl-3-dev libnl-route-3-dev + +WORKDIR $WORKSPACE_DIR/ + +# The drivers should upgrade with each release and match the host version +RUN wget https://docs.broadcom.com/docs-and-downloads/ethernet-network-adapters/NXE/Thor2/GCA1/bcm5760x_230.2.52.0a.zip +RUN unzip bcm5760x_230.2.52.0a.zip +RUN cd bcm5760x_230.2.52.0a/drivers_linux/bnxt_rocelib/ && \ + results=$(find -name "libbnxt*.tar.gz") && tar -xf $results && \ + untar_dir=$(find . -maxdepth 1 -type d -name "libbnxt*" ! -name "*.tar.gz" | head -n 1) && \ + cd $untar_dir && sh autogen.sh && ./configure && make && \ + find /usr/lib64/ /usr/lib -name "libbnxt_re-rdmav*.so" -exec mv {} {}.inbox \; && \ + make install all && sudo sh -c "echo /usr/local/lib >> /etc/ld.so.conf" && \ + sudo ldconfig && \ + cp -f bnxt_re.driver /etc/libibverbs.d/ && \ + find . -name "*.so" -exec md5sum {} \; && \ + BUILT_MD5SUM=$(find . -name "libbnxt_re-rdmav*.so" -exec md5sum {} \; | cut -d " " -f 1) && \ + echo -e "\n\nmd5sum of the built libbnxt_re is $BUILT_MD5SUM" + +RUN ibv_devices + + + diff --git a/multi_node/wan_multinode_train.sbatch b/multi_node/wan_multinode_train.sbatch new file mode 100755 index 00000000..03544ff0 --- /dev/null +++ b/multi_node/wan_multinode_train.sbatch @@ -0,0 +1,192 @@ +#!/bin/bash +#SBATCH --job-name=wan_multinode # Job name +#SBATCH --nodes=2 # Number of nodes (adjust as needed) +#SBATCH --ntasks-per-node=1 # One task per node +#SBATCH --gres=gpu:8 # Assuming 8 GPUs per node, adjust if different +#SBATCH --mem=0 # Use all available memory +#SBATCH --time=48:00:00 # Max runtime (48 hours) +#SBATCH --output=logs/wan_train_%j.out # Standard output (%j = job ID) +#SBATCH --error=logs/wan_train_%j.err # Standard error +#SBATCH --exclusive # Exclusive node access +#SBATCH --account=amd-silo-tgr # Your account/project name +#SBATCH --partition=amd-silo-tgr # Partition/queue name +##SBATCH --nodelist=chi-mi300x-[002,005,013,019] + + +timestamp=$(date +%Y%m%d-%H%M%S) + +MULTI_NODES_LOG_DIR="/home/jianhmei/code/github/multi_nodes" + +# Docker image to use +IMAGE_TAG="jianhan-wan-multinode-train:v1" +SHARE_DOCKERFILE_PATH="${MULTI_NODES_LOG_DIR}/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile" + +# shared code base path +SHARED_CODE_BASE_PATH="/home/jianhmei/code/github/upstream_maxdiff/maxdiffusion" +MAXDIFFUSION_DIR_IN_DOCKER="/app/maxdiffusion" + +# scaling test exp NAME +export EXP_NAME=WAN_1_3B_FSDP8_2N_${timestamp} + +export LOG_DIR="${MULTI_NODES_LOG_DIR}/slurm_logs-${EXP_NAME}/${SLURM_JOB_NUM_NODES}-${SLURM_JOB_ID}" +export SRUN_OUTPUT="${LOG_DIR}/node_%N_rank_%t.log" +# Create log directory +mkdir -p $LOG_DIR + +export HOSTNAME=$(hostname) +OUTPUT_DIR="${MULTI_NODES_LOG_DIR}/output/output-${EXP_NAME}/${HOSTNAME}-${SLURM_JOB_NUM_NODES}-${SLURM_JOB_ID}" +OUTPUT_DIR_IN_DOCKER="/app/output-${EXP_NAME}/${HOSTNAME}-${SLURM_JOB_NUM_NODES}-${SLURM_JOB_ID}" +# Create output directory if it doesn't exist +mkdir -p ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR}/configs/models + +# Registry credentials +export REGISTRY_USERNAME="" +export REGISTRY_TOKEN="" + +# Define log file paths with timestamp +export HOST_OUTPUT_LOG="${LOG_DIR}/host_output.out" +export HOST_ERROR_LOG="${LOG_DIR}/host_output.err" + +# Redirect stdout and stderr to the new log files +exec > >(tee -a "${HOST_OUTPUT_LOG}") 2> >(tee -a "${HOST_ERROR_LOG}" >&2) + + +# Get the coordinator IP (first node in the allocation) +COORDINATOR_IP=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +echo "List of assigned hostnames:$(scontrol show hostnames $SLURM_JOB_NODELIST)" +NNODES=$SLURM_JOB_NUM_NODES + +# Add error checking +if [ -z "$COORDINATOR_IP" ]; then + echo "Error: Could not determine coordinator IP" + exit 1 +fi + +# container cleanup: +srun bash -c ' + # Clean up any existing containers (with proper permission handling) + if docker ps -q | grep -q .; then + echo "Stopping existing containers on $(hostname)" + docker stop $(docker ps -q) + else + echo "No running containers to stop on $(hostname)" + fi + + # Remove stopped containers + if docker ps -aq | grep -q .; then + echo "Removing stopped containers on $(hostname)" + docker rm $(docker ps -aq) + else + echo "No containers to remove on $(hostname)" + fi + + docker image rm -f '"$IMAGE_TAG"' +' + +# # Also update the Docker login section to use sudo: +# echo "Logging into Docker Hub on all nodes" +# srun --export=ALL,REGISTRY_USERNAME,REGISTRY_TOKEN bash -c ' +# echo "Node $(hostname): Logging into Docker Hub" +# echo '"$REGISTRY_TOKEN"' | docker login docker.io -u '"$REGISTRY_USERNAME"' --password-stdin + +# if [ $? -ne 0 ]; then +# echo "Failed to login to Docker Hub on $(hostname)!" +# exit 1 +# else +# echo "Successfully logged into Docker Hub on $(hostname)" +# fi +# ' + + +# Update the image pull section to use sudo: +echo "Building the container image on all nodes" +srun bash -c ' + MAX_RETRIES=5 + INITIAL_DELAY=30 # seconds + MAX_DELAY=1800 # seconds + RETRY_COUNT=0 + + while true; do + echo "Node $(hostname): Pulling image '"$IMAGE_TAG"' (Attempt $((RETRY_COUNT + 1)))" + echo "Attempting to pull image: $IMAGE_TO_PULL (Attempt $((RETRY_COUNT + 1)))" + docker build --tag '"$IMAGE_TAG"' --file '"$SHARE_DOCKERFILE_PATH"' . + + if [ $? -eq 0 ]; then + echo "Image built successfully." + break + else + RETRY_COUNT=$((RETRY_COUNT + 1)) + if [ "$RETRY_COUNT" -ge "$MAX_RETRIES" ]; then + echo "Failed to build image after $MAX_RETRIES attempts. Exiting." + exit 1 + fi + + CURRENT_DELAY=$((INITIAL_DELAY * (2 ** (RETRY_COUNT - 1)))) + if [ "$CURRENT_DELAY" -gt "$MAX_DELAY" ]; then + CURRENT_DELAY="$MAX_DELAY" + fi + + echo "Pull failed. Retrying in $CURRENT_DELAY seconds..." + sleep "$CURRENT_DELAY" + fi + done + + echo "Done with image building" +' + +srun --nodes=$SLURM_JOB_NUM_NODES --ntasks=$SLURM_JOB_NUM_NODES bash -c ' + echo "Cleaning containers on $(hostname)..." + docker ps -aq | xargs -r docker rm -f +' + +# Add a small delay to ensure cleanup is complete +sleep 5 + +# Modified srun command with proper variable expansion + +# --mount type=bind,source='"${OUTPUT_DIR}"',target='"${OUTPUT_DIR_IN_DOCKER}"' \ +# --mount type=bind,source='"${SHARED_CODE_BASE_PATH}"',target='"${MAXDIFFUSION_DIR_IN_DOCKER}"' \ +echo "Starting the training on all nodes" +srun --output=$SRUN_OUTPUT --nodes=$SLURM_JOB_NUM_NODES --ntasks=$SLURM_JOB_NUM_NODES \ + --export=ALL \ + bash -c ' + NODE_RANK=$SLURM_PROCID + NNODES=$SLURM_JOB_NUM_NODES + docker run --rm --privileged --network host \ + --cap-add=IPC_LOCK \ + --volume /dev/infiniband:/dev/infiniband \ + --tmpfs /dev/shm:size=200G \ + --volume '"${OUTPUT_DIR}"':'"${OUTPUT_DIR_IN_DOCKER}"' \ + --volume '"${SHARED_CODE_BASE_PATH}"':'"${MAXDIFFUSION_DIR_IN_DOCKER}"' \ + -e JAX_COORDINATOR_IP='"${COORDINATOR_IP}"' \ + -e JAX_COORDINATOR_PORT=12345 \ + -e NNODES=$NNODES \ + -e HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + -e NODE_RANK=$NODE_RANK \ + -e JAX_DISTRIBUTED_INITIALIZATION_TIMEOUT_SECONDS=1800 \ + -w '"${MAXDIFFUSION_DIR_IN_DOCKER}"' \ + '"${IMAGE_TAG}"' \ + /bin/bash -c " + # Error handling + set -ex + set -o pipefail # Make sure pipe failures are caught + trap \"echo \\\"Error on line \$LINENO\\\"\" ERR + + echo '"${IMAGE_TAG}"' + echo \"Starting node \$NODE_RANK of \$NNODES\" + echo \"Coordinator IP: \$JAX_COORDINATOR_IP\" + + # Create output directory + export BASE_OUTPUT_DIRECTORY=\"'"${OUTPUT_DIR_IN_DOCKER}"'\" + mkdir -p \${BASE_OUTPUT_DIRECTORY} + + cd '"${MAXDIFFUSION_DIR_IN_DOCKER}"' + bash launch.sh LOG_PATH='"${OUTPUT_DIR_IN_DOCKER}"' + + " +' + +srun bash -c ' + docker image rm -f '"$IMAGE_TAG"' +' diff --git a/multi_node/wan_multinode_train_build_docker.sh b/multi_node/wan_multinode_train_build_docker.sh new file mode 100644 index 00000000..dcfcb228 --- /dev/null +++ b/multi_node/wan_multinode_train_build_docker.sh @@ -0,0 +1,118 @@ +#!/bin/bash +# Multi-node JAX training without SLURM +# Usage: bash wan_multinode_train_no_slurm.sh "node1,node2,node3" +# +# Environment Variables: +# MULTI_NODES_LOG_DIR - Base directory for logs and outputs (default: /home/amd/jianhan/github/maxdiffusion/multi_node) +# +# Example with custom log directory: +# MULTI_NODES_LOG_DIR=/custom/path bash wan_multinode_train_no_slurm.sh "node1,node2" + +set -e + +# ============================================================================ +# CONFIGURATION +# ============================================================================ + +# Node list - comma separated hostnames +# Can be passed as first argument or hardcoded here +if [ -z "$1" ]; then + # Default node list (edit this) + NODE_LIST="core42-2,core42-4" +else + NODE_LIST="$1" +fi + +# Convert comma-separated list to array +IFS=',' read -ra NODES <<< "$NODE_LIST" +NNODES=${#NODES[@]} + +# Coordinator is the first node +COORDINATOR_IP="${NODES[0]}" + +# Paths and configuration +timestamp=$(date +%Y%m%d-%H%M%S) +IMAGE_TAG="jianhan-wan-multinode-train:v1" + +# Base directory for all multi-node logs and outputs (configurable) +MULTI_NODES_LOG_DIR="${MULTI_NODES_LOG_DIR:-/home/amd/jianhan/multi_node_log}" + +echo "========================================" +echo "Multi-node Training Configuration" +echo "========================================" +echo "Total nodes: $NNODES" +echo "Node list: ${NODES[@]}" +echo "Coordinator: $COORDINATOR_IP" +echo "Base log directory: $MULTI_NODES_LOG_DIR" +echo "========================================" + +SHARE_DOCKERFILE_PATH="/home/amd/jianhan/github/maxdiffusion/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile" +SHARED_CODE_BASE_PATH="/home/amd/jianhan/github/maxdiffusion" +MAXDIFFUSION_DIR_IN_DOCKER="/app/maxdiffusion" + +# Experiment name +export EXP_NAME="BUILD_DOCKER_${NNODES}N_${timestamp}" + +# Directories under MULTI_NODES_LOG_DIR +export LOG_DIR="${MULTI_NODES_LOG_DIR}/slurm_logs/${EXP_NAME}/${NNODES}-nodes" + +# Create directories locally +mkdir -p ${LOG_DIR} + +# Log files +export HOST_OUTPUT_LOG="${LOG_DIR}/host_output.out" +export HOST_ERROR_LOG="${LOG_DIR}/host_output.err" + +# Redirect output +exec > >(tee -a "${HOST_OUTPUT_LOG}") 2> >(tee -a "${HOST_ERROR_LOG}" >&2) + +echo "Log directory: $LOG_DIR" +echo "" + +# ============================================================================ +# STEP 2: BUILD DOCKER IMAGE ON ALL NODES +# ============================================================================ + +echo "" +echo "========================================" +echo "STEP 2: Building Docker image on all nodes" +echo "========================================" + +build_pids=() +for node in "${NODES[@]}"; do + echo "Building image on $node..." + ssh "$node" "bash -c ' + MAX_RETRIES=5 + INITIAL_DELAY=30 + MAX_DELAY=180 + RETRY_COUNT=0 + + while true; do + echo \"Node \$(hostname): Building image $IMAGE_TAG (Attempt \$((RETRY_COUNT + 1)))\" + docker build --tag $IMAGE_TAG --file $SHARE_DOCKERFILE_PATH $(dirname $SHARE_DOCKERFILE_PATH) + + if [ \$? -eq 0 ]; then + echo \"Image built successfully on \$(hostname)\" + break + else + RETRY_COUNT=\$((RETRY_COUNT + 1)) + if [ \"\$RETRY_COUNT\" -ge \"\$MAX_RETRIES\" ]; then + echo \"Failed to build image after \$MAX_RETRIES attempts. Exiting.\" + exit 1 + fi + + CURRENT_DELAY=\$((INITIAL_DELAY * (2 ** (RETRY_COUNT - 1)))) + if [ \"\$CURRENT_DELAY\" -gt \"\$MAX_DELAY\" ]; then + CURRENT_DELAY=\"\$MAX_DELAY\" + fi + + echo \"Build failed. Retrying in \$CURRENT_DELAY seconds...\" + sleep \"\$CURRENT_DELAY\" + fi + done + '" > "${LOG_DIR}/build_${node}.log" 2>&1 & + build_pids+=($!) +done + + + diff --git a/multi_node/wan_multinode_train_clean.sh b/multi_node/wan_multinode_train_clean.sh new file mode 100644 index 00000000..df084733 --- /dev/null +++ b/multi_node/wan_multinode_train_clean.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# Multi-node JAX training without SLURM +# Usage: bash wan_multinode_train_no_slurm.sh "node1,node2,node3" +# +# Environment Variables: +# MULTI_NODES_LOG_DIR - Base directory for logs and outputs (default: /home/amd/jianhan/github/maxdiffusion/multi_node) +# +# Example with custom log directory: +# MULTI_NODES_LOG_DIR=/custom/path bash wan_multinode_train_no_slurm.sh "node1,node2" + +set -e + +# ============================================================================ +# CONFIGURATION +# ============================================================================ + +# Node list - comma separated hostnames +# Can be passed as first argument or hardcoded here +if [ -z "$1" ]; then + # Default node list (edit this) + NODE_LIST="core42-2,core42-4" +else + NODE_LIST="$1" +fi + +# Convert comma-separated list to array +IFS=',' read -ra NODES <<< "$NODE_LIST" +NNODES=${#NODES[@]} + +# Coordinator is the first node +COORDINATOR_IP="${NODES[0]}" + +# Paths and configuration +timestamp=$(date +%Y%m%d-%H%M%S) +IMAGE_TAG="jianhan-wan-multinode-train:v1" + +# Base directory for all multi-node logs and outputs (configurable) +MULTI_NODES_LOG_DIR="${MULTI_NODES_LOG_DIR:-/home/amd/jianhan/multi_node_log}" + +echo "========================================" +echo "Multi-node Training Configuration" +echo "========================================" +echo "Total nodes: $NNODES" +echo "Node list: ${NODES[@]}" +echo "Coordinator: $COORDINATOR_IP" +echo "Base log directory: $MULTI_NODES_LOG_DIR" +echo "========================================" + +SHARE_DOCKERFILE_PATH="/home/amd/jianhan/github/maxdiffusion/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile" +SHARED_CODE_BASE_PATH="/home/amd/jianhan/github/maxdiffusion" +MAXDIFFUSION_DIR_IN_DOCKER="/app/maxdiffusion" + +# Experiment name +export EXP_NAME="CLEAN_${NNODES}N_${timestamp}" + +# Directories under MULTI_NODES_LOG_DIR +export LOG_DIR="${MULTI_NODES_LOG_DIR}/slurm_logs/${EXP_NAME}/${NNODES}-nodes" + +# Create directories locally +mkdir -p ${LOG_DIR} + +# Log files +export HOST_OUTPUT_LOG="${LOG_DIR}/host_output.out" +export HOST_ERROR_LOG="${LOG_DIR}/host_output.err" + +# Redirect output +exec > >(tee -a "${HOST_OUTPUT_LOG}") 2> >(tee -a "${HOST_ERROR_LOG}" >&2) + +echo "Log directory: $LOG_DIR" +echo "" + +# ============================================================================ +# STEP 1: CLEANUP EXISTING CONTAINERS +# ============================================================================ + +echo "" +echo "========================================" +echo "STEP 1: Cleaning up existing containers" +echo "========================================" + +# docker image rm -f '"$IMAGE_TAG"' 2>/dev/null || true +for node in "${NODES[@]}"; do + echo "Cleaning containers on $node..." + ssh "$node" 'bash -c '\'' + echo "Cleaning up on $(hostname)..." + docker stop $(docker ps -q) 2>/dev/null || true + docker rm $(docker ps -aq) 2>/dev/null || true + docker image rm -f '"$IMAGE_TAG"' 2>/dev/null || true + echo "Cleanup completed on $(hostname)" + '\''' & +done + +echo "Container cleanup completed on all nodes" + +# ============================================================================ +# STEP 2: SYNC CODEBASE TO ALL NODES +# ============================================================================ + +echo "" +echo "========================================" +echo "STEP 2: Syncing codebase to all nodes" +echo "========================================" + +for node in "${NODES[@]}"; do + echo "Syncing $SHARED_CODE_BASE_PATH to $node..." + ssh "$node" "mkdir -p $(dirname $SHARED_CODE_BASE_PATH)" + rsync -az --delete -e "ssh" "$SHARED_CODE_BASE_PATH/" "$node:$SHARED_CODE_BASE_PATH/" + echo "✓ Synced to $node" +done + +echo "Codebase sync completed on all nodes" +echo "" + diff --git a/multi_node/wan_multinode_train_launch.sh b/multi_node/wan_multinode_train_launch.sh new file mode 100644 index 00000000..f0a2ac85 --- /dev/null +++ b/multi_node/wan_multinode_train_launch.sh @@ -0,0 +1,155 @@ +#!/bin/bash +# Multi-node JAX training without SLURM +# Usage: bash wan_multinode_train_no_slurm.sh "node1,node2,node3" +# +# Environment Variables: +# MULTI_NODES_LOG_DIR - Base directory for logs and outputs (default: /home/amd/jianhan/github/maxdiffusion/multi_node) +# +# Example with custom log directory: +# MULTI_NODES_LOG_DIR=/custom/path bash wan_multinode_train_no_slurm.sh "node1,node2" + +set -e + +# ============================================================================ +# CONFIGURATION +# ============================================================================ + +# Node list - comma separated hostnames +# Can be passed as first argument or hardcoded here +if [ -z "$1" ]; then + # Default node list (edit this) + NODE_LIST="core42-2,core42-4" +else + NODE_LIST="$1" +fi + +# Convert comma-separated list to array +IFS=',' read -ra NODES <<< "$NODE_LIST" +NNODES=${#NODES[@]} + +# Coordinator is the first node +COORDINATOR_IP="${NODES[0]}" + +# Paths and configuration +timestamp=$(date +%Y%m%d-%H%M%S) +IMAGE_TAG="jianhan-wan-multinode-train:v1" + +# Base directory for all multi-node logs and outputs (configurable) +MULTI_NODES_LOG_DIR="${MULTI_NODES_LOG_DIR:-/home/amd/jianhan/multi_node_log}" + +echo "========================================" +echo "Multi-node Training Configuration" +echo "========================================" +echo "Total nodes: $NNODES" +echo "Node list: ${NODES[@]}" +echo "Coordinator: $COORDINATOR_IP" +echo "Base log directory: $MULTI_NODES_LOG_DIR" +echo "========================================" + +SHARE_DOCKERFILE_PATH="/home/amd/jianhan/github/maxdiffusion/multi_node/docker/jax_maxdiffusion_wan2.1_train_inference.ubuntu.amd.Dockerfile" +SHARED_CODE_BASE_PATH="/home/amd/jianhan/github/maxdiffusion" +MAXDIFFUSION_DIR_IN_DOCKER="/app/maxdiffusion" + +# Experiment name +export EXP_NAME="WAN_1_3B_FSDP8_${NNODES}N_${timestamp}" + +# Directories under MULTI_NODES_LOG_DIR +export LOG_DIR="${MULTI_NODES_LOG_DIR}/slurm_logs/${EXP_NAME}/${NNODES}-nodes" +OUTPUT_DIR="${MULTI_NODES_LOG_DIR}/output/${EXP_NAME}/${NNODES}-nodes" +OUTPUT_DIR_IN_DOCKER="/app/output-${EXP_NAME}/${NNODES}-nodes" + +# Create directories locally +mkdir -p ${LOG_DIR} +mkdir -p ${OUTPUT_DIR} +mkdir -p ${OUTPUT_DIR}/configs/models + +# Log files +export HOST_OUTPUT_LOG="${LOG_DIR}/host_output.out" +export HOST_ERROR_LOG="${LOG_DIR}/host_output.err" + +# Redirect output +exec > >(tee -a "${HOST_OUTPUT_LOG}") 2> >(tee -a "${HOST_ERROR_LOG}" >&2) + +echo "Log directory: $LOG_DIR" +echo "Output directory: $OUTPUT_DIR" +echo "" + +# ============================================================================ +# STEP 3: LAUNCH TRAINING CONTAINERS +# ============================================================================ + +echo "" +echo "========================================" +echo "STEP 3: Launching training containers" +echo "========================================" +echo "Coordinator IP: $COORDINATOR_IP" +echo "JAX Coordinator Port: 12345" +echo "" + +# --volume /dev/infiniband:/dev/infiniband \ + +launch_pids=() +for i in "${!NODES[@]}"; do + NODE="${NODES[$i]}" + NODE_RANK=$i + + echo "Launching training on $NODE (rank $NODE_RANK/$NNODES)..." + + ssh "$NODE" "bash -c ' + docker run --rm --privileged --network host \ + --cap-add=IPC_LOCK \ + --tmpfs /dev/shm:size=200G \ + --volume $OUTPUT_DIR:$OUTPUT_DIR_IN_DOCKER \ + --volume $SHARED_CODE_BASE_PATH:$MAXDIFFUSION_DIR_IN_DOCKER \ + -e JAX_COORDINATOR_IP=$COORDINATOR_IP \ + -e JAX_COORDINATOR_PORT=12345 \ + -e NNODES=$NNODES \ + -e HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ + -e NODE_RANK=$NODE_RANK \ + -e JAX_DISTRIBUTED_INITIALIZATION_TIMEOUT_SECONDS=1800 \ + -w $MAXDIFFUSION_DIR_IN_DOCKER \ + $IMAGE_TAG \ + /bin/bash -c \" + set -ex + set -o pipefail + trap \\\"echo \\\\\\\"Error on line \\\$LINENO\\\\\\\"\\\" ERR + + echo \\\"Image: $IMAGE_TAG\\\" + echo \\\"Starting node $NODE_RANK of $NNODES\\\" + echo \\\"Coordinator IP: \\\$JAX_COORDINATOR_IP\\\" + echo \\\"Node: \\\$(hostname)\\\" + + # Create output directory + export BASE_OUTPUT_DIRECTORY=\\\"$OUTPUT_DIR_IN_DOCKER\\\" + mkdir -p \\\${BASE_OUTPUT_DIRECTORY} + + chmod 777 $MAXDIFFUSION_DIR_IN_DOCKER -R + cd $MAXDIFFUSION_DIR_IN_DOCKER + bash launch.sh LOG_PATH=$OUTPUT_DIR_IN_DOCKER + \" + '" > "${LOG_DIR}/node_${NODE}_rank_${NODE_RANK}.log" 2>&1 & + + launch_pids+=($!) + + # Small delay between launches to avoid race conditions + sleep 2 +done + +echo "" +echo "All containers launched. Waiting for completion..." +echo "Monitor logs in: $LOG_DIR" +echo "" + +# ============================================================================ +# SUMMARY +# ============================================================================ + +echo "" +echo "========================================" +echo "Training Launched!" +echo "========================================" +echo "Experiment: $EXP_NAME" +echo "Nodes: $NNODES" +echo "Logs: $LOG_DIR" +echo "Output: $OUTPUT_DIR" + diff --git a/src/maxdiffusion/__init__.py b/src/maxdiffusion/__init__.py index 42e50d77..a1a2c2f5 100644 --- a/src/maxdiffusion/__init__.py +++ b/src/maxdiffusion/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" __version__ = "0.22.0.dev0" @@ -84,25 +84,23 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["models"].extend( - [ - "AsymmetricAutoencoderKL", - "AutoencoderKL", - "AutoencoderTiny", - "ControlNetModel", - "ModelMixin", - "MultiAdapter", - "PriorTransformer", - "T2IAdapter", - "T5FilmDecoder", - "Transformer2DModel", - "UNet1DModel", - "UNet2DConditionModel", - "UNet2DModel", - "UNet3DConditionModel", - "VQModel", - ] - ) + _import_structure["models"].extend([ + "AsymmetricAutoencoderKL", + "AutoencoderKL", + "AutoencoderTiny", + "ControlNetModel", + "ModelMixin", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", + "UNet1DModel", + "UNet2DConditionModel", + "UNet2DModel", + "UNet3DConditionModel", + "VQModel", + ]) _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -113,56 +111,52 @@ "get_scheduler", ] - _import_structure["pipelines"].extend( - [ - "AudioPipelineOutput", - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - "ConsistencyModelPipeline", - "DanceDiffusionPipeline", - "DDIMPipeline", - "DDPMPipeline", - "DiffusionPipeline", - "DiTPipeline", - "ImagePipelineOutput", - "KarrasVePipeline", - "LDMPipeline", - "LDMSuperResolutionPipeline", - "PNDMPipeline", - "RePaintPipeline", - "ScoreSdeVePipeline", - ] - ) - _import_structure["schedulers"].extend( - [ - "CMStochasticIterativeScheduler", - "DDIMInverseScheduler", - "DDIMParallelScheduler", - "DDIMScheduler", - "DDPMParallelScheduler", - "DDPMScheduler", - "DDPMWuerstchenScheduler", - "DEISMultistepScheduler", - "DPMSolverMultistepInverseScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "EulerAncestralDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "IPNDMScheduler", - "KarrasVeScheduler", - "KDPM2AncestralDiscreteScheduler", - "KDPM2DiscreteScheduler", - "PNDMScheduler", - "RePaintScheduler", - "SchedulerMixin", - "ScoreSdeVeScheduler", - "UnCLIPScheduler", - "UniPCMultistepScheduler", - "VQDiffusionScheduler", - ] - ) + _import_structure["pipelines"].extend([ + "AudioPipelineOutput", + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + "ConsistencyModelPipeline", + "DanceDiffusionPipeline", + "DDIMPipeline", + "DDPMPipeline", + "DiffusionPipeline", + "DiTPipeline", + "ImagePipelineOutput", + "KarrasVePipeline", + "LDMPipeline", + "LDMSuperResolutionPipeline", + "PNDMPipeline", + "RePaintPipeline", + "ScoreSdeVePipeline", + ]) + _import_structure["schedulers"].extend([ + "CMStochasticIterativeScheduler", + "DDIMInverseScheduler", + "DDIMParallelScheduler", + "DDIMScheduler", + "DDPMParallelScheduler", + "DDPMScheduler", + "DDPMWuerstchenScheduler", + "DEISMultistepScheduler", + "DPMSolverMultistepInverseScheduler", + "DPMSolverMultistepScheduler", + "DPMSolverSinglestepScheduler", + "EulerAncestralDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "IPNDMScheduler", + "KarrasVeScheduler", + "KDPM2AncestralDiscreteScheduler", + "KDPM2DiscreteScheduler", + "PNDMScheduler", + "RePaintScheduler", + "SchedulerMixin", + "ScoreSdeVeScheduler", + "UnCLIPScheduler", + "UniPCMultistepScheduler", + "VQDiffusionScheduler", + ]) _import_structure["training_utils"] = ["EMAModel"] try: @@ -202,100 +196,98 @@ ] else: - _import_structure["pipelines"].extend( - [ - "AltDiffusionImg2ImgPipeline", - "AltDiffusionPipeline", - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - "AudioLDMPipeline", - "BlipDiffusionControlNetPipeline", - "BlipDiffusionPipeline", - "CLIPImageProjection", - "CycleDiffusionPipeline", - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - "ImageTextPipelineOutput", - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - "LDMTextToImagePipeline", - "MusicLDMPipeline", - "PaintByExamplePipeline", - "SemanticStableDiffusionPipeline", - "ShapEImg2ImgPipeline", - "ShapEPipeline", - "StableDiffusionAdapterPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPipelineSafe", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableDiffusionXLAdapterPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "UnCLIPImageVariationPipeline", - "UnCLIPPipeline", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - "VideoToVideoSDPipeline", - "VQDiffusionPipeline", - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "AltDiffusionImg2ImgPipeline", + "AltDiffusionPipeline", + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", + "CLIPImageProjection", + "CycleDiffusionPipeline", + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + "ImageTextPipelineOutput", + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + "LDMTextToImagePipeline", + "MusicLDMPipeline", + "PaintByExamplePipeline", + "SemanticStableDiffusionPipeline", + "ShapEImg2ImgPipeline", + "ShapEPipeline", + "StableDiffusionAdapterPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPipelineSafe", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableDiffusionXLAdapterPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "UnCLIPImageVariationPipeline", + "UnCLIPPipeline", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + "VideoToVideoSDPipeline", + "VQDiffusionPipeline", + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ]) try: if not (is_torch_available() and is_k_diffusion_available()): @@ -321,16 +313,14 @@ ] else: - _import_structure["pipelines"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ]) try: if not (is_torch_available() and is_librosa_available()): @@ -376,19 +366,17 @@ _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) - _import_structure["schedulers"].extend( - [ - "FlaxDDIMScheduler", - "FlaxDDPMScheduler", - "FlaxDPMSolverMultistepScheduler", - "FlaxEulerDiscreteScheduler", - "FlaxKarrasVeScheduler", - "FlaxLMSDiscreteScheduler", - "FlaxPNDMScheduler", - "FlaxSchedulerMixin", - "FlaxScoreSdeVeScheduler", - ] - ) + _import_structure["schedulers"].extend([ + "FlaxDDIMScheduler", + "FlaxDDPMScheduler", + "FlaxDPMSolverMultistepScheduler", + "FlaxEulerDiscreteScheduler", + "FlaxKarrasVeScheduler", + "FlaxLMSDiscreteScheduler", + "FlaxPNDMScheduler", + "FlaxSchedulerMixin", + "FlaxScoreSdeVeScheduler", + ]) try: @@ -403,16 +391,14 @@ else: - _import_structure["pipelines"].extend( - [ - "FlaxStableDiffusionControlNetPipeline", - "FlaxStableDiffusionXLControlNetPipeline", - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["pipelines"].extend([ + "FlaxStableDiffusionControlNetPipeline", + "FlaxStableDiffusionXLControlNetPipeline", + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + "FlaxStableDiffusionXLPipeline", + ]) try: if not (is_note_seq_available()): diff --git a/src/maxdiffusion/checkpointing/__init__.py b/src/maxdiffusion/checkpointing/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/checkpointing/__init__.py +++ b/src/maxdiffusion/checkpointing/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py index 9faba8bc..baf5bdd6 100644 --- a/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py +++ b/src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import ABC from contextlib import nullcontext @@ -66,7 +66,6 @@ def __init__(self, config, checkpoint_type): ) def _create_optimizer(self, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) @@ -74,7 +73,6 @@ def _create_optimizer(self, config, learning_rate): return tx, learning_rate_scheduler def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training): - tx, learning_rate_scheduler = None, None if is_training: learning_rate = self.config.learning_rate @@ -96,7 +94,6 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training) return unet_state, state_mesh_shardings, learning_rate_scheduler def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): - # Currently VAE training is not supported. weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng) return max_utils.setup_initial_state( @@ -112,7 +109,6 @@ def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=F ) def create_text_encoder_state(self, pipeline, params, checkpoint_item_name, is_training): - tx = None if is_training: learning_rate = self.config.text_encoder_learning_rate @@ -260,11 +256,9 @@ def config_to_json(model_or_config): self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step=None): - self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX def load_checkpoint(self, step=None, scheduler_class=None): - pipeline_class = self._get_pipeline_class() self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py index bbad3ad1..960e0692 100644 --- a/src/maxdiffusion/checkpointing/checkpointing_utils.py +++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py @@ -1,19 +1,19 @@ # ruff: noqa """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Create an Orbax CheckpointManager with specified (Async or not) Checkpointer.""" diff --git a/src/maxdiffusion/checkpointing/flux_checkpointer.py b/src/maxdiffusion/checkpointing/flux_checkpointer.py old mode 100644 new mode 100755 index 89ac3764..c18a0589 --- a/src/maxdiffusion/checkpointing/flux_checkpointer.py +++ b/src/maxdiffusion/checkpointing/flux_checkpointer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import ABC from contextlib import nullcontext @@ -20,6 +20,7 @@ import json import jax from jax.sharding import Mesh +from flax.traverse_util import flatten_dict, unflatten_dict import orbax.checkpoint as ocp import grain.python as grain from maxdiffusion import ( @@ -67,7 +68,6 @@ def __init__(self, config, checkpoint_type): ) def _create_optimizer(self, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps ) @@ -104,12 +104,16 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training) training=is_training, ) if not self.config.train_new_flux: - flux_state = flux_state.replace(params=transformer_params) - flux_state = jax.device_put(flux_state, state_mesh_shardings) + with self.mesh: + flat_state_shardings = flatten_dict(state_mesh_shardings.params) + param_state = flatten_dict(flux_state.params) + for path, val in flatten_dict(transformer_params).items(): + sharding = flat_state_shardings[path] + param_state[path].value = max_utils.device_put_replicated(val, sharding) + flux_state = flux_state.replace(params=unflatten_dict(param_state)) return flux_state, state_mesh_shardings, learning_rate_scheduler def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False): - # Currently VAE training is not supported. weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=self.rng) return max_utils.setup_initial_state( @@ -163,7 +167,6 @@ def config_to_json(model_or_config): self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) def load_params(self, step=None): - self.checkpoint_format = _CHECKPOINT_FORMAT_ORBAX def load_flux_configs_from_orbax(self, step): @@ -243,7 +246,6 @@ def load_diffusers_checkpoint(self): return pipeline, params def load_checkpoint(self, step=None, scheduler_class=None): - model_configs = self.load_flux_configs_from_orbax(step) pipeline, params = None, {} diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 006b3ec8..4ab90971 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from abc import ABC, abstractmethod @@ -35,14 +35,12 @@ def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT): self.checkpoint_type = checkpoint_type self.opt_state = None - self.checkpoint_manager: ocp.CheckpointManager = ( - create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, - ) + self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, ) def _create_optimizer(self, model, config, learning_rate): @@ -61,13 +59,18 @@ def load_diffusers_checkpoint(self): raise NotImplementedError @abstractmethod - def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int]]: + def load_checkpoint( + self, step=None + ) -> Tuple[ + Optional[WanPipeline2_1 | WanPipeline2_2 | WanPipelineI2V_2_1 | WanPipelineI2V_2_2], Optional[dict], Optional[int] + ]: raise NotImplementedError @abstractmethod def save_checkpoint(self, train_step, pipeline, train_states: dict): raise NotImplementedError + def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py index a8e2a297..da30567b 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointer2_1(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 30cff387..533a00db 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointer2_2(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: @@ -38,7 +39,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle low_noise_transformer low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + abstract_tree_structure_low_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata + ) low_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -48,7 +51,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle high_noise_transformer high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + abstract_tree_structure_high_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata + ) high_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -67,10 +72,18 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log( + f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}" + ) max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py index 6f4bbc90..5850692f 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointerI2V_2_1(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py index a55048cf..98f76f48 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import json @@ -24,6 +24,7 @@ from etils import epath from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer + class WanCheckpointerI2V_2_2(WanCheckpointer): def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: @@ -38,7 +39,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle low_noise_transformer low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + abstract_tree_structure_low_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata + ) low_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -48,7 +51,9 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic # Handle high_noise_transformer high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + abstract_tree_structure_high_params = jax.tree_util.tree_map( + ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata + ) high_params_restore = ocp.args.PyTreeRestore( restore_args=jax.tree.map( lambda _: ocp.RestoreArgs(restore_type=np.ndarray), @@ -67,10 +72,18 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic ), ) max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log( + f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}" + ) + max_logging.log( + f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}" + ) max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") return restored_checkpoint, step diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 15553727..e85d270a 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -36,6 +36,7 @@ # Physical axis names for device meshes. DATA = "data" FSDP = "fsdp" +CONTEXT = "context" TENSOR = "tensor" # Logical axis names for model parameters and activations. BATCH = "activation_batch" @@ -66,19 +67,19 @@ ### Common axis rules for ring attention ### RING_ATTENTION_AXIS_RULES = [ - [SELF_ATTN_HEAD, None], - [SELF_ATTN_Q_LENGTH, FSDP], - [SELF_ATTN_KV_LENGTH, FSDP], - [CROSS_ATTN_HEAD, None], - [CROSS_ATTN_Q_LENGTH, FSDP], - [CROSS_ATTN_KV_LENGTH, FSDP], + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, CONTEXT], + [SELF_ATTN_KV_LENGTH, CONTEXT], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, CONTEXT], + [CROSS_ATTN_KV_LENGTH, CONTEXT], ] SEQUENCE_PARALLEL_AXIS_RULES = [ - [SELF_ATTN_HEAD, None], - [SELF_ATTN_Q_LENGTH, FSDP], - [SELF_ATTN_KV_LENGTH, None], - [CROSS_ATTN_HEAD, None], - [CROSS_ATTN_Q_LENGTH, FSDP], - [CROSS_ATTN_KV_LENGTH, None], + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, CONTEXT], + [SELF_ATTN_KV_LENGTH, None], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, CONTEXT], + [CROSS_ATTN_KV_LENGTH, None], ] diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index 7bd8ae70..ca2579d9 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -106,7 +106,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -131,7 +131,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -139,9 +139,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index 24dffe40..65e7d19e 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -108,7 +108,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -133,7 +133,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -141,9 +141,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 7b224058..16948296 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -121,7 +121,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -146,7 +146,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -154,9 +154,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 0036b363..7a508095 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -158,7 +158,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -166,9 +166,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False @@ -177,7 +179,20 @@ allow_split_physical_axes: False # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tf' +dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic' +# ============================================================================== +# Synthetic Data Configuration (only used when dataset_type='synthetic') +# ============================================================================== +# To use synthetic data for testing/debugging without real datasets: +# 1. Set dataset_type: 'synthetic' above +# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000) +# 3. Optionally override dimensions +# +# synthetic_num_samples: null # null for infinite, or set a number +# +# Optional dimension overrides: +# resolution: 512 +# ============================================================================== cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index ac0a0f8c..1aba7431 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -132,7 +132,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -158,7 +158,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -166,9 +166,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index c60dd79e..9ae39971 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -140,7 +140,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -166,7 +166,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -174,9 +174,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index b2a11dba..91a3e092 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -151,7 +151,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -166,31 +166,33 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_in : conv.shape[2] weight # conv_out : conv.shape[-1] weight logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_self_attn_heads', ['fsdp', 'tensor']], - ['activation_cross_attn_q_length', ['fsdp', 'tensor']], - ['activation_length', 'fsdp'], + ['batch', ['data', 'fsdp']], + ['activation_batch', ['data', 'fsdp']], + ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_cross_attn_q_length', ['context', 'tensor']], + ['activation_length', 'context'], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed', ['context', 'fsdp']], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data', 'context', 'fsdp']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'context'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: -1 +dcn_fsdp_parallelism: 1 +dcn_context_parallelism: -1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_fsdp_parallelism: 1 +ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 allow_split_physical_axes: False @@ -199,7 +201,28 @@ allow_split_physical_axes: False # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' train_split: 'train' -dataset_type: 'tfrecord' +dataset_type: 'tfrecord' # Options: 'tfrecord', 'hf', 'tf', 'grain', 'synthetic' +# ============================================================================== +# Synthetic Data Configuration (only used when dataset_type='synthetic') +# ============================================================================== +# To use synthetic data for testing/debugging without real datasets: +# 1. Set dataset_type: 'synthetic' above +# 2. Optionally set synthetic_num_samples (null=infinite, or a number like 10000) +# 3. Optionally override dimensions with synthetic_override_* flags below +# +# synthetic_num_samples: null # null for infinite, or set a number +# +# Optional dimension overrides (comment out to use pipeline/config values): +# synthetic_override_height: 720 +# synthetic_override_width: 1280 +# synthetic_override_num_frames: 121 +# synthetic_override_max_sequence_length: 512 +# synthetic_override_text_embed_dim: 4096 +# synthetic_override_num_channels_latents: 16 +# synthetic_override_vae_scale_factor_spatial: 8 +# synthetic_override_vae_scale_factor_temporal: 4 +# ============================================================================== + cache_latents_text_encoder_outputs: True # cache_latents_text_encoder_outputs only apply to dataset_type="tf", # only apply to small dataset that fits in memory @@ -315,12 +338,14 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras"], + weight_name: ["wan2.1_t2v_14b_lora_rank64_lightx2v_4step.safetensors"], + adapter_name: ["wan21-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index cff70a94..022b18c9 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -139,7 +139,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -154,30 +154,33 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_in : conv.shape[2] weight # conv_out : conv.shape[-1] weight logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_length', 'fsdp'], - + ['batch', ['data', 'fsdp']], + ['activation_batch', ['data', 'fsdp']], + ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_cross_attn_q_length', ['context', 'tensor']], + ['activation_length', 'context'], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed', ['context', 'fsdp']], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data', 'context', 'fsdp']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'context'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded -dcn_fsdp_parallelism: -1 +dcn_fsdp_parallelism: 1 +dcn_context_parallelism: -1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 -ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_fsdp_parallelism: 1 +ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 allow_split_physical_axes: False @@ -297,7 +300,7 @@ guidance_scale_high: 4.0 # The timestep threshold. If `t` is at or above this value, # the `high_noise_model` is considered as the required model. # timestep to switch between low noise and high noise transformer -boundary_timestep: 875 +boundary_ratio: 0.875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 @@ -313,12 +316,15 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], + high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + adapter_name: ["wan22-distill-lora"], + scale: [1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 92a371e3..f8d2ff95 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -27,7 +27,7 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 -pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-480P-Diffusers' +pretrained_model_name_or_path: 'Wan-AI/Wan2.1-I2V-14B-720P-Diffusers' model_name: wan2.1 model_type: 'I2V' @@ -134,7 +134,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -149,21 +149,21 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_in : conv.shape[2] weight # conv_out : conv.shape[-1] weight logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_self_attn_heads', ['fsdp', 'tensor']], - ['activation_cross_attn_q_length', ['fsdp', 'tensor']], - ['activation_length', 'fsdp'], + ['batch', ['data', 'fsdp']], + ['activation_batch', ['data', 'fsdp']], + ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_cross_attn_q_length', ['context', 'tensor']], + ['activation_length', 'context'], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed', ['context', 'fsdp']], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data', 'context', 'fsdp']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'context'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -171,9 +171,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False @@ -274,20 +276,20 @@ profiler_steps: 10 enable_jax_named_scopes: False # Generation parameters -prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." -prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 480 -width: 832 +height: 720 +width: 1280 num_frames: 81 guidance_scale: 5.0 -flow_shift: 3.0 +flow_shift: 5.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 30 -fps: 24 +num_inference_steps: 50 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters @@ -298,12 +300,14 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64, 32], + lora_model_name_or_path: ["lightx2v/Wan2.1-Distill-Loras", "starsfriday/Wan2.1-Divine-Power-LoRA"], + weight_name: ["wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors", "divine-power.safetensors"], + adapter_name: ["wan21-distill-lora-i2v", "divine-power-lora"], + scale: [1.0, 1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index f8982b44..5aae2d5a 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -135,7 +135,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -150,21 +150,21 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_in : conv.shape[2] weight # conv_out : conv.shape[-1] weight logical_axis_rules: [ - ['batch', 'data'], - ['activation_batch', 'data'], - ['activation_self_attn_heads', ['fsdp', 'tensor']], - ['activation_cross_attn_q_length', ['fsdp', 'tensor']], - ['activation_length', 'fsdp'], + ['batch', ['data', 'fsdp']], + ['activation_batch', ['data', 'fsdp']], + ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_cross_attn_q_length', ['context', 'tensor']], + ['activation_length', 'context'], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed','fsdp'], + ['embed', ['context', 'fsdp']], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data','fsdp']], + ['conv_batch', ['data', 'context', 'fsdp']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], + ['conv_out', 'context'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -172,9 +172,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False @@ -275,21 +277,21 @@ profiler_steps: 10 enable_jax_named_scopes: False # Generation parameters -prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." -prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt_2: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True -height: 480 -width: 832 +height: 720 +width: 1280 num_frames: 81 -flow_shift: 3.0 +flow_shift: 5.0 # Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py # guidance scale factor for low noise transformer -guidance_scale_low: 3.0 +guidance_scale_low: 3.0 # guidance scale factor for high noise transformer -guidance_scale_high: 4.0 +guidance_scale_high: 4.0 # The timestep threshold. If `t` is at or above this value, # the `high_noise_model` is considered as the required model. @@ -298,8 +300,8 @@ boundary_ratio: 0.875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 -num_inference_steps: 30 -fps: 24 +num_inference_steps: 50 +fps: 16 save_final_checkpoint: False # SDXL Lightning parameters @@ -310,12 +312,15 @@ lightning_repo: "" lightning_ckpt: "" # LoRA parameters +enable_lora: False # Values are lists to support multiple LoRA loading during inference in the future. lora_config: { - lora_model_name_or_path: [], - weight_name: [], - adapter_name: [], - scale: [], + rank: [64, 16], + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras", "ostris/wan22_i2v_14b_orbit_shot_lora"], + high_noise_weight_name: ["wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors", "wan22_14b_i2v_orbit_high_noise.safetensors"], + low_noise_weight_name: ["wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors", "wan22_14b_i2v_orbit_low_noise.safetensors"], # Empty or "wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors" + adapter_name: ["wan22-distill-lora", "wan22-orbit-lora"], + scale: [1.0, 1.0], from_pt: [] } # Ex with values: diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 49e53ae5..3dbb1578 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -106,7 +106,7 @@ base_output_directory: "" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -131,7 +131,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -139,9 +139,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 6f6662b0..e487559a 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -86,7 +86,7 @@ skip_jax_distributed_system: False base_output_directory: "" # Parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] # batch : batch dimension of data and activations # hidden : @@ -111,7 +111,7 @@ logical_axis_rules: [ ['out_channels', 'tensor'], ['conv_out', 'fsdp'], ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -119,9 +119,11 @@ data_sharding: [['data', 'fsdp', 'tensor']] # and product of the ICI axes should equal number of devices per slice. dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configs/ltx_video.yml b/src/maxdiffusion/configs/ltx_video.yml index 71316ea1..4328da18 100644 --- a/src/maxdiffusion/configs/ltx_video.yml +++ b/src/maxdiffusion/configs/ltx_video.yml @@ -1,6 +1,8 @@ #hardware hardware: 'tpu' skip_jax_distributed_system: False +attention: 'flash' +attention_sharding_uniform: True jax_cache_dir: '' weights_dtype: 'bfloat16' @@ -62,7 +64,7 @@ second_pass: cfg_star_rescale: True #parallelism -mesh_axes: ['data', 'fsdp', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] logical_axis_rules: [ ['batch', 'data'], ['activation_heads', 'fsdp'], @@ -77,12 +79,14 @@ logical_axis_rules: [ ['conv_out', 'fsdp'], ['conv_in', 'fsdp'] ] -data_sharding: [['data', 'fsdp', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: -1 +dcn_context_parallelism: 1 dcn_tensor_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_context_parallelism: 1 ici_tensor_parallelism: 1 allow_split_physical_axes: False diff --git a/src/maxdiffusion/configuration_utils.py b/src/maxdiffusion/configuration_utils.py index 3a495e02..0e8c9968 100644 --- a/src/maxdiffusion/configuration_utils.py +++ b/src/maxdiffusion/configuration_utils.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" ConfigMixin base class and utilities.""" +"""ConfigMixin base class and utilities.""" import dataclasses import functools import importlib @@ -611,7 +611,6 @@ def to_json_saveable(value): config_dict.pop(key) try: - json_str = json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomEncoder) except Exception as e: max_logging.log(f"Error serializing config to JSON: {e}") diff --git a/src/maxdiffusion/controlnet/__init__.py b/src/maxdiffusion/controlnet/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/controlnet/__init__.py +++ b/src/maxdiffusion/controlnet/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/controlnet/generate_controlnet_replicated.py b/src/maxdiffusion/controlnet/generate_controlnet_replicated.py index a3959cbb..bd4ef6eb 100644 --- a/src/maxdiffusion/controlnet/generate_controlnet_replicated.py +++ b/src/maxdiffusion/controlnet/generate_controlnet_replicated.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence from absl import app @@ -28,7 +28,6 @@ def run(config): - rng = jax.random.PRNGKey(config.seed) # get canny image diff --git a/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py b/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py index b38202c8..235159a9 100644 --- a/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py +++ b/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence from absl import app diff --git a/src/maxdiffusion/data_preprocessing/__init__.py b/src/maxdiffusion/data_preprocessing/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/data_preprocessing/__init__.py +++ b/src/maxdiffusion/data_preprocessing/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index e0191373..64e9d54b 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ """ @@ -71,7 +71,6 @@ def create_example(latent, hidden_states, timestep=None): def generate_dataset(config): - tfrecords_dir = config.tfrecords_dir if not os.path.exists(tfrecords_dir): os.makedirs(tfrecords_dir) diff --git a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py index ae0b15f4..23baaffb 100644 --- a/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py +++ b/src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ """ @@ -84,7 +84,6 @@ def vae_encode(video, rng, vae, vae_cache): def generate_dataset(config, pipeline): - tfrecords_dir = config.tfrecords_dir if not os.path.exists(tfrecords_dir): os.makedirs(tfrecords_dir) diff --git a/src/maxdiffusion/dreambooth/__init__.py b/src/maxdiffusion/dreambooth/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/dreambooth/__init__.py +++ b/src/maxdiffusion/dreambooth/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/dreambooth/dreambooth_constants.py b/src/maxdiffusion/dreambooth/dreambooth_constants.py index 72ac6003..bb366e15 100644 --- a/src/maxdiffusion/dreambooth/dreambooth_constants.py +++ b/src/maxdiffusion/dreambooth/dreambooth_constants.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" INSTANCE_IMAGES = "instance_images" INSTANCE_IMAGE_LATENTS = "instance_image_latents" diff --git a/src/maxdiffusion/dreambooth/train_dreambooth.py b/src/maxdiffusion/dreambooth/train_dreambooth.py index 5cb7e233..d9b17475 100644 --- a/src/maxdiffusion/dreambooth/train_dreambooth.py +++ b/src/maxdiffusion/dreambooth/train_dreambooth.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index ac4fbb7f..7b1f1f62 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import functools import time @@ -86,7 +86,6 @@ def tokenize(prompt, tokenizer): def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -132,7 +131,6 @@ def vae_decode(latents, state, pipeline): def run_inference(states, pipeline, params, config, rng, mesh, batch_size): - unet_state = states["unet_state"] vae_state = states["vae_state"] @@ -158,7 +156,6 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size): def run(config): - checkpoint_loader = GenerateSD(config, STABLE_DIFFUSION_CHECKPOINT) pipeline, params = checkpoint_loader.load_checkpoint() diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index b248156e..0ba8a7a8 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Callable, List, Union, Sequence @@ -137,7 +137,6 @@ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: flo def run_inference( states, transformer, vae, config, mesh, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): - transformer_state = states["transformer"] vae_state = states["vae"] @@ -175,7 +174,6 @@ def pack_latents( def prepare_latents( batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array ): - # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. height = 2 * (height // (vae_scale_factor * 2)) @@ -223,7 +221,6 @@ def get_t5_prompt_embeds( text_encoder: T5EncoderModel, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( @@ -256,7 +253,6 @@ def encode_prompt( num_images_per_prompt: int = 1, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt prompt_2 = prompt or prompt_2 prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py index 7d07883c..33179295 100644 --- a/src/maxdiffusion/generate_flux_multi_res.py +++ b/src/maxdiffusion/generate_flux_multi_res.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import List, Union, Sequence @@ -154,7 +154,6 @@ def run_inference( p_ts, vae_scale_factor, ): - transformer_state = states["transformer"] vae_state = states["vae"] @@ -194,7 +193,6 @@ def pack_latents( def prepare_latents( batch_size: int, num_channels_latents: int, height: int, width: int, vae_scale_factor: int, dtype: jnp.dtype, rng: Array ): - # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. height = 2 * (height // (vae_scale_factor * 2)) @@ -270,7 +268,6 @@ def get_t5_prompt_embeds( text_encoder: T5EncoderModel, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = tokenizer( @@ -303,7 +300,6 @@ def encode_prompt( num_images_per_prompt: int = 1, max_sequence_length: int = 512, ): - prompt = [prompt] if isinstance(prompt, str) else prompt prompt_2 = prompt or prompt_2 prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py index e6b8d4e2..c89f413a 100644 --- a/src/maxdiffusion/generate_flux_pipeline.py +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Sequence diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 6ecc6666..93753f0c 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import numpy as np from absl import app @@ -34,7 +34,6 @@ def calculate_padding( source_height: int, source_width: int, target_height: int, target_width: int ) -> tuple[int, int, int, int]: - # Calculate total padding needed pad_height = target_height - source_height pad_width = target_width - source_width diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 9ad1022d..3ab70370 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import functools from absl import app @@ -115,7 +115,6 @@ def tokenize(prompt, pipeline): def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size): - data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding)) vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) @@ -189,7 +188,6 @@ def vae_decode(latents, state, pipeline): def run_inference(states, pipeline, params, config, rng, mesh, batch_size): - unet_state = states["unet_state"] vae_state = states["vae_state"] diff --git a/src/maxdiffusion/generate_sdxl_replicated.py b/src/maxdiffusion/generate_sdxl_replicated.py index d17fc02d..83df3a99 100644 --- a/src/maxdiffusion/generate_sdxl_replicated.py +++ b/src/maxdiffusion/generate_sdxl_replicated.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import time diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d3aad31d..7c03a21c 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -28,6 +28,7 @@ from google.cloud import storage import flax from maxdiffusion.common_types import WAN2_1, WAN2_2 +from maxdiffusion.loaders.wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader def upload_video_to_gcs(output_dir: str, video_path: str): @@ -66,10 +67,11 @@ def delete_file(file_path: str): else: max_logging.log(f"The file '{file_path}' does not exist.") + def get_git_commit_hash(): """Tries to get the current Git commit hash.""" try: - commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8') + commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8") return commit_hash except subprocess.CalledProcessError: max_logging.log("Warning: 'git rev-parse HEAD' failed. Not running in a git repo?") @@ -78,8 +80,10 @@ def get_git_commit_hash(): max_logging.log("Warning: 'git' command not found.") return None + jax.config.update("jax_use_shardy_partitioner", True) + def call_pipeline(config, pipeline, prompt, negative_prompt): model_key = config.model_name model_type = config.model_type @@ -131,7 +135,6 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_inference_steps=config.num_inference_steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, - boundary=config.boundary_timestep, ) else: raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}") @@ -159,13 +162,12 @@ def inference_generate_video(config, pipeline, filename_prefix=""): return -def run(config, pipeline=None, filename_prefix=""): +def run(config, pipeline=None, filename_prefix="", commit_hash=None): model_key = config.model_name writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") - commit_hash = get_git_commit_hash() if commit_hash: writer.add_text("inference/git_commit_hash", commit_hash, global_step=0) max_logging.log(f"Git Commit Hash: {commit_hash}") @@ -187,6 +189,43 @@ def run(config, pipeline=None, filename_prefix=""): else: raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") pipeline, _, _ = checkpoint_loader.load_checkpoint() + + # If LoRA is specified, inject layers and load weights. + if ( + config.enable_lora + and hasattr(config, "lora_config") + and config.lora_config + and config.lora_config["lora_model_name_or_path"] + ): + if model_key == WAN2_1: + lora_loader = Wan2_1NNXLoraLoader() + lora_config = config.lora_config + for i in range(len(lora_config["lora_model_name_or_path"])): + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][i], + transformer_weight_name=lora_config["weight_name"][i], + rank=lora_config["rank"][i], + scale=lora_config["scale"][i], + scan_layers=config.scan_layers, + dtype=config.weights_dtype, + ) + + if model_key == WAN2_2: + lora_loader = Wan2_2NNXLoraLoader() + lora_config = config.lora_config + for i in range(len(lora_config["lora_model_name_or_path"])): + pipeline = lora_loader.load_lora_weights( + pipeline, + lora_config["lora_model_name_or_path"][i], + high_noise_weight_name=lora_config["high_noise_weight_name"][i], + low_noise_weight_name=lora_config["low_noise_weight_name"][i], + rank=lora_config["rank"][i], + scale=lora_config["scale"][i], + scan_layers=config.scan_layers, + dtype=config.weights_dtype, + ) + s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables @@ -247,9 +286,13 @@ def run(config, pipeline=None, filename_prefix=""): def main(argv: Sequence[str]) -> None: + commit_hash = get_git_commit_hash() pyconfig.initialize(argv) - flax.config.update("flax_always_shard_variable", False) - run(pyconfig.config) + try: + flax.config.update("flax_always_shard_variable", False) + except LookupError: + pass + run(pyconfig.config, commit_hash=commit_hash) if __name__ == "__main__": diff --git a/src/maxdiffusion/input_pipeline/__init__.py b/src/maxdiffusion/input_pipeline/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/input_pipeline/__init__.py +++ b/src/maxdiffusion/input_pipeline/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/input_pipeline/_grain_data_processing.py b/src/maxdiffusion/input_pipeline/_grain_data_processing.py index 5ba3b637..6498b263 100644 --- a/src/maxdiffusion/input_pipeline/_grain_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_grain_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import dataclasses import glob diff --git a/src/maxdiffusion/input_pipeline/_hf_data_processing.py b/src/maxdiffusion/input_pipeline/_hf_data_processing.py index e0f1d725..10f276d6 100644 --- a/src/maxdiffusion/input_pipeline/_hf_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_hf_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import warnings import datasets diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py index b8992415..dae9a3a1 100644 --- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py +++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import tensorflow as tf @@ -41,7 +41,6 @@ def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_cou def make_tf_iterator( config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, tokenize_fn, image_transforms_fn ): - if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location): train_ds = load_from_disk(config.dataset_save_location) else: diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py index 27f2ad25..6c00e0f6 100644 --- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py +++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os from functools import partial @@ -23,6 +23,7 @@ from maxdiffusion.input_pipeline import _hf_data_processing from maxdiffusion.input_pipeline import _grain_data_processing from maxdiffusion.input_pipeline import _tfds_data_processing +from maxdiffusion.input_pipeline import synthetic_data_iterator from maxdiffusion import multihost_dataloading from maxdiffusion.maxdiffusion_utils import tokenize_captions, transform_images, vae_apply from maxdiffusion.dreambooth.dreambooth_constants import ( @@ -54,8 +55,9 @@ def make_data_iterator( feature_description=None, prepare_sample_fn=None, is_training=True, + pipeline=None, ): - """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)""" + """Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord, grain, synthetic)""" if config.dataset_type == "hf" or config.dataset_type == "tf": if tokenize_fn is None or image_transforms_fn is None: @@ -110,8 +112,16 @@ def make_data_iterator( prepare_sample_fn, is_training, ) + elif config.dataset_type == "synthetic": + return synthetic_data_iterator.make_synthetic_iterator( + config=config, + mesh=mesh, + global_batch_size=global_batch_size, + pipeline=pipeline, + is_training=is_training, + ) else: - assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)" + assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain, synthetic)" def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, vae, vae_params): diff --git a/src/maxdiffusion/input_pipeline/synthetic_data_iterator.py b/src/maxdiffusion/input_pipeline/synthetic_data_iterator.py new file mode 100755 index 00000000..928823c5 --- /dev/null +++ b/src/maxdiffusion/input_pipeline/synthetic_data_iterator.py @@ -0,0 +1,492 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +from typing import Dict, Any, Optional +import numpy as np +import jax +import jax.numpy as jnp + +from maxdiffusion import multihost_dataloading, max_logging + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def get_wan_dimension( + config, + pipeline, + config_key: str, + pipeline_path: str = None, + default_value: Any = None +) -> Any: + """ + Get dimension for WAN model with override priority: + 1. Config override (synthetic_override_{config_key}) - for height, width, num_frames + 2. Pipeline path (exact path specified by caller) + 3. Config default + 4. Hardcoded default + + Args: + config: Configuration object + pipeline: WAN Pipeline object + config_key: Key to look up in config + pipeline_path: Exact dotted path in pipeline (e.g., 'transformer.config.in_channels') + default_value: Fallback value if not found elsewhere + """ + # Check overrides for height, width, num_frames (WAN-specific) + if config_key in ['height', 'width', 'num_frames']: + override_key = f'synthetic_override_{config_key}' + try: + value = getattr(config, override_key) + if value is not None: + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using override {config_key}: {value}") + return value + except (AttributeError, ValueError): + pass # Override not set, continue to pipeline/config + + # Check pipeline using exact path if provided + if pipeline is not None and pipeline_path: + try: + # Navigate the dotted path (e.g., 'transformer.config.in_channels') + value = pipeline + for attr in pipeline_path.split('.'): + value = getattr(value, attr) + + if value is not None: + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using {config_key} from pipeline.{pipeline_path}: {value}") + return value + except AttributeError: + pass # Path not available in pipeline + + # Check config - use try/except because config raises ValueError instead of AttributeError + try: + value = getattr(config, config_key) + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using {config_key} from config: {value}") + return value + except (AttributeError, ValueError): + pass # Key not in config, use default + + # Use default + if jax.process_index() == 0: + max_logging.log(f"[WAN] Using default {config_key}: {default_value}") + return default_value + + +def get_flux_dimension( + config, + pipeline, + config_key: str, + pipeline_path: str = None, + default_value: Any = None +) -> Any: + """ + Get dimension for FLUX model with override priority: + 1. Pipeline path (exact path specified by caller) + 2. Config default + 3. Hardcoded default + + Note: FLUX does not support override flags + + Args: + config: Configuration object + pipeline: FLUX Pipeline object + config_key: Key to look up in config + pipeline_path: Exact dotted path in pipeline (e.g., 'vae_scale_factor') + default_value: Fallback value if not found elsewhere + """ + # FLUX does not check overrides - load directly from pipeline/config + + # Check pipeline using exact path if provided + if pipeline is not None and pipeline_path: + try: + # Navigate the dotted path (e.g., 'vae_scale_factor') + value = pipeline + for attr in pipeline_path.split('.'): + value = getattr(value, attr) + + if value is not None: + if jax.process_index() == 0: + max_logging.log(f"[FLUX] Using {config_key} from pipeline.{pipeline_path}: {value}") + return value + except AttributeError: + pass # Path not available in pipeline + + # Check config - use try/except because config raises ValueError instead of AttributeError + try: + value = getattr(config, config_key) + if jax.process_index() == 0: + max_logging.log(f"[FLUX] Using {config_key} from config: {value}") + return value + except (AttributeError, ValueError): + pass # Key not in config, use default + + # Use default + if jax.process_index() == 0: + max_logging.log(f"[FLUX] Using default {config_key}: {default_value}") + return default_value + + +def log_synthetic_config(model_name: str, dimensions: Dict[str, Any], per_host_batch_size: int, is_training: bool, num_samples: Optional[int]): + """Log synthetic data configuration.""" + if jax.process_index() == 0: + info = [ + "=" * 60, + f"{model_name.upper()} Synthetic Data Iterator Configuration:", + f" Per-host batch size: {per_host_batch_size}", + f" Mode: {'Training' if is_training else 'Evaluation'}", + f" Samples per iteration: {num_samples if num_samples else 'Infinite'}", + ] + for key, value in dimensions.items(): + info.append(f" {key}: {value}") + info.append("=" * 60) + max_logging.log("\n".join(info)) + + +# ============================================================================ +# Synthetic Data Source and Iterator +# ============================================================================ + + +class SyntheticDataSource: + """Wrapper for synthetic data that provides iterator interface.""" + + def __init__(self, generate_fn, num_samples, seed): + self.generate_fn = generate_fn + self.num_samples = num_samples + self.seed = seed + self.current_step = 0 + self.rng_key = jax.random.key(seed) + + def __iter__(self): + self.current_step = 0 + self.rng_key = jax.random.key(self.seed) + return self + + def __next__(self): + if self.num_samples is not None and self.current_step >= self.num_samples: + raise StopIteration + + self.rng_key, step_key = jax.random.split(self.rng_key) + data = self.generate_fn(step_key) + self.current_step += 1 + return data + + def as_numpy_iterator(self): + return iter(self) + + +# ============================================================================ +# WAN Model Synthetic Data Generator +# ============================================================================ + + +def _generate_wan_sample(rng_key: jax.Array, dimensions: Dict[str, Any], is_training: bool) -> Dict[str, np.ndarray]: + """Generate a single batch of synthetic data for WAN model.""" + keys = jax.random.split(rng_key, 3) + + per_host_batch_size = dimensions['per_host_batch_size'] + + # Generate latents: (batch, channels, frames, height, width) + latents_shape = ( + per_host_batch_size, + dimensions['num_channels_latents'], + dimensions['num_latent_frames'], + dimensions['latent_height'], + dimensions['latent_width'] + ) + latents = jax.random.normal(keys[0], shape=latents_shape, dtype=jnp.float32) + + # Generate encoder hidden states: (batch, seq_len, embed_dim) + encoder_hidden_states_shape = ( + per_host_batch_size, + dimensions['max_sequence_length'], + dimensions['text_embed_dim'] + ) + encoder_hidden_states = jax.random.normal(keys[1], shape=encoder_hidden_states_shape, dtype=jnp.float32) + + data = { + 'latents': np.array(latents), + 'encoder_hidden_states': np.array(encoder_hidden_states), + } + + # For evaluation, also generate timesteps + if not is_training: + timesteps = jax.random.randint( + keys[2], + shape=(per_host_batch_size,), + minval=0, + maxval=dimensions['num_train_timesteps'], + dtype=jnp.int64 + ) + data['timesteps'] = np.array(timesteps) + + return data + + +def _make_wan_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples): + """Create synthetic data iterator for WAN model.""" + per_host_batch_size = global_batch_size // jax.process_count() + + # Initialize dimensions - explicitly specify pipeline paths for WAN model + height = get_wan_dimension( + config, pipeline, 'height', + pipeline_path=None, # Not in pipeline, use config/override + default_value=480 + ) + width = get_wan_dimension( + config, pipeline, 'width', + pipeline_path=None, # Not in pipeline, use config/override + default_value=832 + ) + num_frames = get_wan_dimension( + config, pipeline, 'num_frames', + pipeline_path=None, # Not in pipeline, use config/override + default_value=81 + ) + + # WAN-specific dimensions from transformer config + max_sequence_length = get_wan_dimension( + config, pipeline, 'max_sequence_length', + pipeline_path='transformer.config.rope_max_seq_len', + default_value=512 + ) + text_embed_dim = get_wan_dimension( + config, pipeline, 'text_embed_dim', + pipeline_path='transformer.config.text_dim', + default_value=4096 + ) + num_channels_latents = get_wan_dimension( + config, pipeline, 'num_channels_latents', + pipeline_path='transformer.config.in_channels', + default_value=16 + ) + + # VAE scale factors from pipeline attributes + vae_scale_factor_spatial = get_wan_dimension( + config, pipeline, 'vae_scale_factor_spatial', + pipeline_path='vae_scale_factor_spatial', + default_value=8 + ) + vae_scale_factor_temporal = get_wan_dimension( + config, pipeline, 'vae_scale_factor_temporal', + pipeline_path='vae_scale_factor_temporal', + default_value=4 + ) + + # Calculate latent dimensions + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + latent_height = height // vae_scale_factor_spatial + latent_width = width // vae_scale_factor_spatial + + # Get num_train_timesteps from scheduler + num_train_timesteps = get_wan_dimension( + config, pipeline, 'num_train_timesteps', + pipeline_path='scheduler.config.num_train_timesteps', + default_value=1000 + ) + # Fallback to scheduler.num_train_timesteps if config doesn't exist + if pipeline is not None and hasattr(pipeline, 'scheduler') and num_train_timesteps == 1000: + try: + num_train_timesteps = pipeline.scheduler.num_train_timesteps + if jax.process_index() == 0: + max_logging.log(f"Using num_train_timesteps from pipeline.scheduler: {num_train_timesteps}") + except AttributeError: + pass + + dimensions = { + 'per_host_batch_size': per_host_batch_size, + 'height': height, + 'width': width, + 'num_frames': num_frames, + 'num_latent_frames': num_latent_frames, + 'latent_height': latent_height, + 'latent_width': latent_width, + 'max_sequence_length': max_sequence_length, + 'text_embed_dim': text_embed_dim, + 'num_channels_latents': num_channels_latents, + 'vae_scale_factor_spatial': vae_scale_factor_spatial, + 'vae_scale_factor_temporal': vae_scale_factor_temporal, + 'num_train_timesteps': num_train_timesteps, + } + + log_synthetic_config('WAN', dimensions, per_host_batch_size, is_training, num_samples) + + # Create generate function with dimensions bound + def generate_fn(rng_key): + return _generate_wan_sample(rng_key, dimensions, is_training) + + data_source = SyntheticDataSource(generate_fn, num_samples, config.seed) + return multihost_dataloading.MultiHostDataLoadIterator(data_source, mesh) + + +# ============================================================================ +# FLUX Model Synthetic Data Generator +# ============================================================================ + + +def _generate_flux_sample(rng_key: jax.Array, dimensions: Dict[str, Any]) -> Dict[str, np.ndarray]: + """Generate a single batch of synthetic data for FLUX model.""" + keys = jax.random.split(rng_key, 4) + + per_host_batch_size = dimensions['per_host_batch_size'] + latent_height = dimensions['latent_height'] + latent_width = dimensions['latent_width'] + latent_seq_len = dimensions['latent_seq_len'] + loss_scaling_factor = 0.1 # magic factor for matching the flux loss. + + # Generate pixel values (packed latents) - should be float16 to match trainer + pixel_values_shape = (per_host_batch_size, latent_seq_len, dimensions['packed_latent_dim']) + pixel_values = jax.random.normal(keys[0], shape=pixel_values_shape, dtype=jnp.float16) + + # Generate text embedding IDs (position encodings) + input_ids_shape = (per_host_batch_size, dimensions['max_sequence_length'], 3) + input_ids = jax.random.normal(keys[1], shape=input_ids_shape, dtype=jnp.float32) + + # Generate text embeddings (T5) + text_embeds_shape = (per_host_batch_size, dimensions['max_sequence_length'], dimensions['t5_embed_dim']) + text_embeds = jax.random.normal(keys[2], shape=text_embeds_shape, dtype=jnp.float32) + + # Generate pooled prompt embeddings (CLIP) + prompt_embeds_shape = (per_host_batch_size, dimensions['pooled_embed_dim']) + prompt_embeds = loss_scaling_factor * jax.random.normal(keys[3], shape=prompt_embeds_shape, dtype=jnp.float32) + + # Generate image position IDs - matching pipeline.prepare_latent_image_ids + # Create base img_ids for single sample (without batch dimension) + img_ids_base = jnp.zeros((latent_height, latent_width, 3), dtype=jnp.float16) + # Channel 0 stays 0 + # Channel 1 = height indices + img_ids_base = img_ids_base.at[..., 1].set(jnp.arange(latent_height)[:, None]) + # Channel 2 = width indices + img_ids_base = img_ids_base.at[..., 2].set(jnp.arange(latent_width)[None, :]) + + # Reshape to (latent_seq_len, 3) + img_ids_base = img_ids_base.reshape(latent_seq_len, 3) + + # Tile for batch dimension + img_ids = jnp.tile(img_ids_base[None, ...], (per_host_batch_size, 1, 1)) + + return { + 'pixel_values': np.array(pixel_values), + 'input_ids': np.array(input_ids), + 'text_embeds': np.array(text_embeds), + 'prompt_embeds': np.array(prompt_embeds), + 'img_ids': np.array(img_ids), + } + + +def _make_flux_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples): + """Create synthetic data iterator for FLUX model.""" + per_host_batch_size = global_batch_size // jax.process_count() + + # Initialize dimensions - explicitly specify pipeline paths for FLUX model + resolution = get_flux_dimension( + config, pipeline, 'resolution', + pipeline_path=None, # Not in pipeline, use config + default_value=512 + ) + max_sequence_length = get_flux_dimension( + config, pipeline, 'max_sequence_length', + pipeline_path=None, # Not in pipeline, use config + default_value=512 + ) + t5_embed_dim = get_flux_dimension( + config, pipeline, 't5_embed_dim', + pipeline_path='text_encoder_2.config.d_model', # T5 model dimension + default_value=4096 + ) + pooled_embed_dim = get_flux_dimension( + config, pipeline, 'pooled_embed_dim', + pipeline_path='text_encoder.config.projection_dim', # CLIP projection dimension + default_value=768 + ) + vae_scale_factor = get_flux_dimension( + config, pipeline, 'vae_scale_factor', + pipeline_path='vae_scale_factor', # Direct pipeline attribute + default_value=8 + ) + + # Calculate packed latent dimensions + latent_height = math.ceil(resolution // (vae_scale_factor * 2)) + latent_width = math.ceil(resolution // (vae_scale_factor * 2)) + latent_seq_len = latent_height * latent_width + packed_latent_dim = 64 # 16 channels * 2 * 2 packing + + dimensions = { + 'per_host_batch_size': per_host_batch_size, + 'max_sequence_length': max_sequence_length, + 't5_embed_dim': t5_embed_dim, + 'pooled_embed_dim': pooled_embed_dim, + 'resolution': resolution, + 'latent_height': latent_height, + 'latent_width': latent_width, + 'latent_seq_len': latent_seq_len, + 'packed_latent_dim': packed_latent_dim, + } + + log_synthetic_config('FLUX', dimensions, per_host_batch_size, is_training, num_samples) + + # Create generate function with dimensions bound + def generate_fn(rng_key): + return _generate_flux_sample(rng_key, dimensions) + + data_source = SyntheticDataSource(generate_fn, num_samples, config.seed) + return multihost_dataloading.MultiHostDataLoadIterator(data_source, mesh) + + +# ============================================================================ +# Public API +# ============================================================================ + + +def make_synthetic_iterator(config, mesh, global_batch_size, pipeline=None, is_training=True): + """ + Create a synthetic data iterator for the specified model. + + Args: + config: Configuration object with model_name + mesh: JAX mesh for sharding + global_batch_size: Total batch size across all devices + pipeline: Optional pipeline object to extract dimensions from + is_training: Whether this is for training or evaluation + + Returns: + MultiHostDataLoadIterator wrapping the synthetic data source + """ + num_samples = getattr(config, 'synthetic_num_samples', None) + + try: + model_name = getattr(config, 'model_name', None) + if model_name in ('wan2.1', 'wan2.2'): + return _make_wan_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples) + except (AttributeError, ValueError): + pass + try: + model_name = getattr(config, 'flux_name', None) + if model_name in ('flux', 'flux-dev', 'flux-schnell'): + return _make_flux_synthetic_iterator(config, mesh, global_batch_size, pipeline, is_training, num_samples) + except (AttributeError, ValueError): + pass + + raise ValueError( + f"No synthetic iterator implemented for model." + f"Supported models: wan2.1, wan2.2, flux, flux-dev, flux-schnell" + ) diff --git a/src/maxdiffusion/loaders/__init__.py b/src/maxdiffusion/loaders/__init__.py index 2c9e973d..48d204db 100644 --- a/src/maxdiffusion/loaders/__init__.py +++ b/src/maxdiffusion/loaders/__init__.py @@ -14,3 +14,4 @@ from .lora_pipeline import StableDiffusionLoraLoaderMixin from .flux_lora_pipeline import FluxLoraLoaderMixin +from .wan_lora_nnx_loader import Wan2_1NNXLoraLoader, Wan2_2NNXLoraLoader diff --git a/src/maxdiffusion/loaders/flux_lora_pipeline.py b/src/maxdiffusion/loaders/flux_lora_pipeline.py index 5f449ee9..56844db7 100644 --- a/src/maxdiffusion/loaders/flux_lora_pipeline.py +++ b/src/maxdiffusion/loaders/flux_lora_pipeline.py @@ -22,7 +22,6 @@ class FluxLoraLoaderMixin(LoRABaseMixin): - _lora_lodable_modules = ["transformer", "text_encoder"] def load_lora_weights( @@ -98,7 +97,6 @@ def _get_lora_layer(cls, module_path, module, rank, network_alphas, adapter_name @classmethod @validate_hf_hub_args def lora_state_dict(cls, pretrained_model_name_or_path: str, **kwargs): - cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) diff --git a/src/maxdiffusion/loaders/lora_conversion_utils.py b/src/maxdiffusion/loaders/lora_conversion_utils.py index 5f9e72a6..96bdb0c8 100644 --- a/src/maxdiffusion/loaders/lora_conversion_utils.py +++ b/src/maxdiffusion/loaders/lora_conversion_utils.py @@ -391,7 +391,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] if not is_sparse: # down_weight is copied to each split - ait_sd.update({k: down_weight for k in ait_down_keys}) + ait_sd.update(dict.fromkeys(ait_down_keys, down_weight)) # up_weight is split to each split ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 @@ -534,7 +534,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] # down_weight is copied to each split - ait_sd.update({k: down_weight for k in ait_down_keys}) + ait_sd.update(dict.fromkeys(ait_down_keys, down_weight)) # up_weight is split to each split ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 @@ -608,3 +608,98 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") return new_state_dict + + +def preprocess_wan_lora_dict(state_dict): + """ + Preprocesses WAN LoRA dict to convert diff_m to modulation.diff. + """ + new_d = {} + for k, v in state_dict.items(): + if k.endswith(".diff_m"): + new_k = k.removesuffix(".diff_m") + ".modulation.diff" + new_d[new_k] = v + else: + new_d[k] = v + return new_d + + +def translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=False): + """ + Translates WAN NNX path to Diffusers/LoRA keys. + Verified against wan_utils.py mappings. + """ + + # --- 1. Embeddings (Exact Matches) --- + if nnx_path_str == "condition_embedder.text_embedder.linear_1": + return "diffusion_model.text_embedding.0" + if nnx_path_str == "condition_embedder.text_embedder.linear_2": + return "diffusion_model.text_embedding.2" + if nnx_path_str == "condition_embedder.time_embedder.linear_1": + return "diffusion_model.time_embedding.0" + if nnx_path_str == "condition_embedder.time_embedder.linear_2": + return "diffusion_model.time_embedding.2" + if nnx_path_str == "condition_embedder.image_embedder.norm1.layer_norm": + return "diffusion_model.img_emb.proj.0" + if nnx_path_str == "condition_embedder.image_embedder.ff.net_0": + return "diffusion_model.img_emb.proj.1" + if nnx_path_str == "condition_embedder.image_embedder.ff.net_2": + return "diffusion_model.img_emb.proj.3" + if nnx_path_str == "condition_embedder.image_embedder.norm2.layer_norm": + return "diffusion_model.img_emb.proj.4" + if nnx_path_str == "patch_embedding": + return "diffusion_model.patch_embedding" + if nnx_path_str == "proj_out": + return "diffusion_model.head.head" + if nnx_path_str == "scale_shift_table": + return "diffusion_model.head.modulation" + if nnx_path_str == "condition_embedder.time_proj": + return "diffusion_model.time_projection.1" + + # --- 2. Map NNX Suffixes to LoRA Suffixes --- + suffix_map = { + # Self Attention (attn1) + "attn1.query": "self_attn.q", + "attn1.key": "self_attn.k", + "attn1.value": "self_attn.v", + "attn1.proj_attn": "self_attn.o", + # Self Attention Norms (QK Norm) + "attn1.norm_q": "self_attn.norm_q", + "attn1.norm_k": "self_attn.norm_k", + # Cross Attention (attn2) + "attn2.query": "cross_attn.q", + "attn2.key": "cross_attn.k", + "attn2.value": "cross_attn.v", + "attn2.proj_attn": "cross_attn.o", + # Cross Attention Norms (QK Norm) + "attn2.norm_q": "cross_attn.norm_q", + "attn2.norm_k": "cross_attn.norm_k", + # Cross Attention img + "attn2.add_k_proj": "cross_attn.k_img", + "attn2.add_v_proj": "cross_attn.v_img", + "attn2.norm_added_k": "cross_attn.norm_k_img", + # Feed Forward (ffn) + "ffn.act_fn.proj": "ffn.0", # Up proj + "ffn.proj_out": "ffn.2", # Down proj + # Global Norms & Modulation + "norm2.layer_norm": "norm3", + "adaln_scale_shift_table": "modulation", + "proj_out": "head.head", + } + + # --- 3. Translation Logic --- + if scan_layers: + # Scanned Pattern: "blocks.attn1.query" -> "diffusion_model.blocks.{}.self_attn.q" + if nnx_path_str.startswith("blocks."): + inner_suffix = nnx_path_str[len("blocks.") :] + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{{}}.{suffix_map[inner_suffix]}" + else: + # Unscanned Pattern: "blocks.0.attn1.query" -> "diffusion_model.blocks.0.self_attn.q" + m = re.match(r"^blocks\.(\d+)\.(.+)$", nnx_path_str) + if m: + idx, inner_suffix = m.group(1), m.group(2) + if inner_suffix in suffix_map: + return f"diffusion_model.blocks.{idx}.{suffix_map[inner_suffix]}" + + return None diff --git a/src/maxdiffusion/loaders/lora_pipeline.py b/src/maxdiffusion/loaders/lora_pipeline.py index 7feb20ca..2d8c1c75 100644 --- a/src/maxdiffusion/loaders/lora_pipeline.py +++ b/src/maxdiffusion/loaders/lora_pipeline.py @@ -134,7 +134,6 @@ def rename_for_interceptor(params_keys, network_alphas, adapter_name): @classmethod def make_lora_interceptor(cls, params, rank, network_alphas, adapter_name): - network_alphas_for_interceptor = {} unet_lora_keys = flax.traverse_util.flatten_dict(params["unet"]).keys() diff --git a/src/maxdiffusion/loaders/wan_lora_nnx_loader.py b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py new file mode 100644 index 00000000..a34c0f1a --- /dev/null +++ b/src/maxdiffusion/loaders/wan_lora_nnx_loader.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX-based LoRA loader for WAN models.""" + +from flax import nnx +from .lora_base import LoRABaseMixin +from .lora_pipeline import StableDiffusionLoraLoaderMixin +from ..models import lora_nnx +from .. import max_logging +from . import lora_conversion_utils + + +class Wan2_1NNXLoraLoader(LoRABaseMixin): + """ + Handles loading LoRA weights into NNX-based WAN 2.1 model. + Assumes WAN pipeline contains 'transformer' + attributes that are NNX Modules. + """ + + def load_lora_weights( + self, + pipeline: nnx.Module, + lora_model_path: str, + transformer_weight_name: str, + rank: int, + scale: float = 1.0, + scan_layers: bool = False, + dtype: str = "float32", + **kwargs, + ): + """ + Merges LoRA weights into the pipeline from a checkpoint. + """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + + def translate_fn(nnx_path_str): + return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + if hasattr(pipeline, "transformer") and transformer_weight_name: + max_logging.log(f"Merging LoRA into transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=transformer_weight_name, **kwargs) + h_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(h_state_dict) + merge_fn(pipeline.transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype) + else: + max_logging.log("transformer not found or no weight name provided for LoRA.") + + return pipeline + + +class Wan2_2NNXLoraLoader(LoRABaseMixin): + """ + Handles loading LoRA weights into NNX-based WAN 2.2 model. + Assumes WAN pipeline contains 'high_noise_transformer' and 'low_noise_transformer' + attributes that are NNX Modules. + """ + + def load_lora_weights( + self, + pipeline: nnx.Module, + lora_model_path: str, + high_noise_weight_name: str, + low_noise_weight_name: str, + rank: int, + scale: float = 1.0, + scan_layers: bool = False, + dtype: str = "float32", + **kwargs, + ): + """ + Merges LoRA weights into the pipeline from a checkpoint. + """ + lora_loader = StableDiffusionLoraLoaderMixin() + + merge_fn = lora_nnx.merge_lora_for_scanned if scan_layers else lora_nnx.merge_lora + + def translate_fn(nnx_path_str: str): + return lora_conversion_utils.translate_wan_nnx_path_to_diffusers_lora(nnx_path_str, scan_layers=scan_layers) + + # Handle high noise model + if hasattr(pipeline, "high_noise_transformer") and high_noise_weight_name: + max_logging.log(f"Merging LoRA into high_noise_transformer with rank={rank}") + h_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=high_noise_weight_name, **kwargs) + h_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(h_state_dict) + merge_fn(pipeline.high_noise_transformer, h_state_dict, rank, scale, translate_fn, dtype=dtype) + else: + max_logging.log("high_noise_transformer not found or no weight name provided for LoRA.") + + # Handle low noise model + if hasattr(pipeline, "low_noise_transformer") and low_noise_weight_name: + max_logging.log(f"Merging LoRA into low_noise_transformer with rank={rank}") + l_state_dict, _ = lora_loader.lora_state_dict(lora_model_path, weight_name=low_noise_weight_name, **kwargs) + l_state_dict = lora_conversion_utils.preprocess_wan_lora_dict(l_state_dict) + merge_fn(pipeline.low_noise_transformer, l_state_dict, rank, scale, translate_fn, dtype=dtype) + else: + max_logging.log("low_noise_transformer not found or no weight name provided for LoRA.") + + return pipeline diff --git a/src/maxdiffusion/max_logging.py b/src/maxdiffusion/max_logging.py index 32ac3d8f..2edb43f4 100644 --- a/src/maxdiffusion/max_logging.py +++ b/src/maxdiffusion/max_logging.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Stub for logging utilities. Right now just meant to avoid raw prints""" diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fb7266a1..04b3869f 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -1,19 +1,19 @@ # ruff: noqa """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=bare-except, consider-using-generator """ Common Max Utils needed by multiple modules""" @@ -268,17 +268,30 @@ def create_device_mesh(config, devices=None, logging=True): max_logging.log(f"Devices: {devices} (num_devices: {num_devices})") multi_slice_env = num_slices > 1 - - dcn_parallelism = [ - config.dcn_data_parallelism, - config.dcn_fsdp_parallelism, - config.dcn_tensor_parallelism, - ] - ici_parallelism = [ - config.ici_data_parallelism, - config.ici_fsdp_parallelism, - config.ici_tensor_parallelism, - ] + if "dcn_context_parallelism" in config.get_keys() and "ici_context_parallelism" in config.get_keys(): + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_context_parallelism, + config.dcn_tensor_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_context_parallelism, + config.ici_tensor_parallelism, + ] + else: + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_tensor_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_tensor_parallelism, + ] # Find possible unspecified parallelisms ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") @@ -502,17 +515,20 @@ def get_flash_block_sizes(config): flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: attention_is_tokamax = "tokamax" in config.attention - user_block_sizes:Dict[str, int] = config.flash_block_sizes + user_block_sizes: Dict[str, int] = config.flash_block_sizes if attention_is_tokamax: - max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." - "Hence following flash block properties specified will be ignored:" - f"block_q: {user_block_sizes['block_q']}," - f"block_q_dq: {user_block_sizes.get('block_q_dq')}," - f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," - f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" - ) + max_logging.log( + "Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." + "Hence following flash block properties specified will be ignored:" + f"block_q: {user_block_sizes['block_q']}," + f"block_q_dq: {user_block_sizes.get('block_q_dq')}," + f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," + f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" + ) flash_block_sizes = splash_attention_kernel.BlockSizes( - block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"], + block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) + if attention_is_tokamax + else user_block_sizes["block_q"], block_kv_compute=user_block_sizes["block_kv_compute"], block_kv=user_block_sizes["block_kv"], block_q_dkv=user_block_sizes["block_q_dkv"], @@ -541,7 +557,6 @@ def get_memory_allocations(): def get_live_arrays(): - backend = jax.extend.backend.get_backend() live_arrays = backend.live_arrays() diff --git a/src/maxdiffusion/maxdiffusion_utils.py b/src/maxdiffusion/maxdiffusion_utils.py index b9b1abdc..32eed7f4 100644 --- a/src/maxdiffusion/maxdiffusion_utils.py +++ b/src/maxdiffusion/maxdiffusion_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import io from PIL import Image @@ -42,7 +42,6 @@ def load_sdxllightning_unet(config, pipeline, params): def maybe_load_sdxl_lora(config, pipeline, params): - def _noop_interceptor(next_fn, args, kwargs, context): return next_fn(*args, **kwargs) diff --git a/src/maxdiffusion/models/__init__.py b/src/maxdiffusion/models/__init__.py index 96a6f128..7ff8fd8f 100644 --- a/src/maxdiffusion/models/__init__.py +++ b/src/maxdiffusion/models/__init__.py @@ -24,7 +24,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - from .controlnet_flax import FlaxControlNetModel from .unet_2d_condition_flax import FlaxUNet2DConditionModel from .vae_flax import FlaxAutoencoderKL diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 2982e19e..0e3e24e5 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -20,7 +20,6 @@ from flax import nnx import jax from jax.ad_checkpoint import checkpoint_name -from jax.sharding import PartitionSpec import jax.numpy as jnp from jax.experimental import shard_map from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask @@ -78,8 +77,11 @@ def _reshape_data_from_cudnn_flash(tensor): def _reshape_data_for_cudnn_flash(tensor, heads): # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format) - batch, seq, heads_and_dim_head = tensor.shape - tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + if len(tensor.shape) == 3: + batch, seq, dim_head = tensor.shape + tensor = tensor.reshape(batch, seq, heads, dim_head // heads) + else: + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) return tensor @@ -89,7 +91,8 @@ def _reshape_batch_dim_to_heads(tensor, heads): tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names) def _reshape_heads_to_batch_dim(tensor, heads): @@ -102,8 +105,8 @@ def _reshape_heads_to_batch_dim(tensor, heads): else: batch_size, head_size, seq_len, head_dim = tensor.shape reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) - - return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names) def _reshape_heads_to_head_dim(tensor): @@ -112,7 +115,8 @@ def _reshape_heads_to_head_dim(tensor): b, h, s, d = tensor.shape tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d)) - return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names) def _unflatten_heads(tensor, heads): @@ -173,17 +177,20 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): return tensor, kv_size, seq_len -def convert_to_tokamax_splash_config( block_sizes: BlockSizes, - q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, - residual_checkpoint_name: str | None = None, - attn_logits_soft_cap: float | None = None, - fuse_reciprocal: bool = True, - use_base2_exp: bool = False, - max_logit_const: float | None = None, - interpret: bool = False, - dq_reduction_steps: int | None = None) -> tokamax_splash_attention_kernel.SplashConfig: + +def convert_to_tokamax_splash_config( + block_sizes: BlockSizes, + q_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + k_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + v_layout: tokamax_splash_attention_kernel.QKVLayout = tokamax_splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR, + residual_checkpoint_name: str | None = None, + attn_logits_soft_cap: float | None = None, + fuse_reciprocal: bool = True, + use_base2_exp: bool = False, + max_logit_const: float | None = None, + interpret: bool = False, + dq_reduction_steps: int | None = None, +) -> tokamax_splash_attention_kernel.SplashConfig: assert block_sizes.use_fused_bwd_kernel, "Tokamax Splash attention only supports fused bwd kernel." return tokamax_splash_attention_kernel.SplashConfig( block_q=block_sizes.block_q, @@ -192,7 +199,7 @@ def convert_to_tokamax_splash_config( block_sizes: BlockSizes, block_q_dkv=block_sizes.block_q_dkv, block_kv_dkv=block_sizes.block_kv_dkv, block_kv_dkv_compute=block_sizes.block_kv_dkv_compute, - block_q_dq= None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, + block_q_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_q_dq, block_kv_dq=None if block_sizes.use_fused_bwd_kernel else block_sizes.block_kv_dq, use_fused_bwd_kernel=block_sizes.use_fused_bwd_kernel, q_layout=q_layout, @@ -248,7 +255,7 @@ def _tpu_flash_attention( block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]), use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False, ) - num_fsdp_shards = mesh.shape["fsdp"] + num_context_shards = mesh.shape["context"] query = _reshape_data_for_flash(query, heads) key = _reshape_data_for_flash(key, heads) value = _reshape_data_for_flash(value, heads) @@ -319,7 +326,9 @@ def wrap_flash_attention(query, key, value): # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. if attention_kernel == "tokamax_flash": - mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),) + mask = tokamax_splash_attention_mask.FullMask( + _shape=(query.shape[2], key.shape[2]), + ) splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( mask=mask, q_seq_shards=1, # the sizes of the axis is sharding over seq_len @@ -333,7 +342,7 @@ def wrap_flash_attention(query, key, value): q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, save_residuals=True if attention_kernel == "ring" else False, - residual_checkpoint_name=residual_checkpoint_name + residual_checkpoint_name=residual_checkpoint_name, ) vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) @@ -342,21 +351,21 @@ def wrap_flash_attention(query, key, value): if attention_kernel in ["flash", "tokamax_flash"]: attention_output = vmapped_splash(query, key, value, segment_ids) else: - if num_fsdp_shards > 1: + if num_context_shards > 1: out, (lse,) = vmapped_splash(query, key, value, segment_ids) m = lse.astype(jnp.float32) l = jnp.exp(lse - m) o = out.astype(jnp.float32) * l[..., None] - perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] + perm = [(j, (j + 1) % num_context_shards) for j in range(num_context_shards)] - k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) - v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) + k1 = jax.lax.ppermute(key, axis_name="context", perm=perm) + v1 = jax.lax.ppermute(value, axis_name="context", perm=perm) def ring_scan_body(carry, _): m, l, o, k_current, v_current = carry - k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) - v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) + k_next = jax.lax.ppermute(k_current, axis_name="context", perm=perm) + v_next = jax.lax.ppermute(v_current, axis_name="context", perm=perm) out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) @@ -375,21 +384,22 @@ def ring_scan_body(carry, _): return (m, l, o, k_next, v_next), None initial_carry = (m, l, o, k1, v1) - (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) + (m_final, l_final, o_final, _, _), _ = jax.lax.scan( + ring_scan_body, initial_carry, None, length=num_context_shards - 1 + ) attention_output = o_final / l_final[..., None] else: - raise ValueError("ring attention requires fsdp > 1") - + raise ValueError("ring attention requires context > 1") return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) - devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] + devices_in_data_context = mesh.shape["data"] * mesh.shape["context"] # This warning might show up when doing model eval for example, when calculating model flops # and that is expected. - if not (query.shape[0] / devices_in_data_fsdp).is_integer(): + if not (query.shape[0] / devices_in_data_context).is_integer(): max_logging.log( - "Warning, batch dimension should be shardable among the devices in data and fsdp" - f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" + "Warning, batch dimension should be shardable among the devices in data and context" + f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}" ) x = wrap_flash_attention(query, key, value) x = _reshape_heads_to_head_dim(x) @@ -481,24 +491,12 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m key = _reshape_data_for_cudnn_flash(key, heads) value = _reshape_data_for_cudnn_flash(value, heads) - cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV) - axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names) - - query = nn.with_logical_constraint(query, axis_names) - key = nn.with_logical_constraint(key, axis_names) - value = nn.with_logical_constraint(value, axis_names) - - @functools.partial( - shard_map.shard_map, - mesh=mesh, - in_specs=(axis_names, axis_names, axis_names), - out_specs=axis_names, - check_rep=False, - ) - def wrap_flash_attention(query, key, value): - return jax.vmap(dpa_layer)(query, key, value, mask=None) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV)) + query = jax.lax.with_sharding_constraint(query, axis_names) + key = jax.lax.with_sharding_constraint(key, axis_names) + value = jax.lax.with_sharding_constraint(value, axis_names) - out = wrap_flash_attention(query, key, value) + out = dpa_layer(query, key, value, mask=None) return _reshape_data_from_cudnn_flash(out) @@ -559,7 +557,16 @@ def _apply_attention( ) elif attention_kernel == "ring": return _tpu_flash_attention( - query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, + query, + key * scale, + value, + heads, + mesh, + axis_names_q, + axis_names_kv, + flash_block_sizes, + dtype, + attention_kernel, mask_padding_tokens=mask_padding_tokens, ) elif attention_kernel == "cudnn_flash_te": @@ -671,9 +678,21 @@ def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]: return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype) + # New Class for Wan I2V class NNXSimpleFeedForward(nnx.Module): - def __init__(self, rngs: nnx.Rngs, dim: int, dim_out: Optional[int] = None, mult: int = 4, activation_fn: str = "gelu", dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: Optional[jax.lax.Precision] = None): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + activation_fn: str = "gelu", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: Optional[jax.lax.Precision] = None, + ): inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim self.net_0 = nnx.Linear( @@ -706,6 +725,7 @@ def __call__(self, hidden_states: Array) -> Array: hidden_states = self.net_2(hidden_states) return hidden_states + class NNXAttentionOp(nnx.Module): def __init__( @@ -730,7 +750,25 @@ def __init__( ): self.dpa_layer = None if attention_kernel == "cudnn_flash_te": - raise NotImplementedError(f"{self} has not been tested with {attention_kernel}") + from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + + jax.config.update("jax_use_shardy_partitioner", False) + + dpa_layer = DotProductAttention( + head_dim=dim_head, + num_attention_heads=heads, + num_gqa_groups=heads, + attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + # attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=dtype, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=scale, + transpose_batch_sequence=False, + ) + variables = {} + self.dpa_layer = functools.partial(dpa_layer.apply, variables) self.mesh = mesh self.scale = scale @@ -795,7 +833,9 @@ def setup(self): if self.attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error - self.dpa_layer = DotProductAttention( + jax.config.update("jax_use_shardy_partitioner", False) + + dpa_layer = DotProductAttention( head_dim=self.dim_head, num_attention_heads=self.heads, num_gqa_groups=self.heads, @@ -809,6 +849,8 @@ def setup(self): scale_factor=self.scale, transpose_batch_sequence=False, ) + variables = {} + self.dpa_layer = functools.partial(dpa_layer.apply, variables) def apply_attention(self, query: Array, key: Array, value: Array, attention_mask: Array = None): return _apply_attention( @@ -864,12 +906,9 @@ def __init__( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, enable_jax_named_scopes: bool = False, - added_kv_proj_dim: Optional[int] = None, # New for I2V - image_seq_len: Optional[int] = None, # New for I2V + added_kv_proj_dim: Optional[int] = None, # New for I2V + image_seq_len: Optional[int] = None, # New for I2V ): - if attention_kernel == "cudnn_flash_te": - raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") - if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") self.dim_head = dim_head @@ -889,8 +928,8 @@ def __init__( else: axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV) axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) - self.added_kv_proj_dim = added_kv_proj_dim # New for I2V - self.image_seq_len = image_seq_len # New for I2V + self.added_kv_proj_dim = added_kv_proj_dim # New for I2V + self.image_seq_len = image_seq_len # New for I2V self.attention_op = NNXAttentionOp( mesh=mesh, @@ -1006,23 +1045,35 @@ def __init__( self.norm_added_k = nnx.data(None) if self.added_kv_proj_dim is not None: self.add_k_proj = nnx.Linear( - self.added_kv_proj_dim, self.inner_dim, rngs=rngs, - dtype=dtype, param_dtype=weights_dtype, precision=precision, + self.added_kv_proj_dim, + self.inner_dim, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, ("embed",), ), ) self.add_v_proj = nnx.Linear( - self.added_kv_proj_dim, self.inner_dim, rngs=rngs, - dtype=dtype, param_dtype=weights_dtype, precision=precision, + self.added_kv_proj_dim, + self.inner_dim, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, ("embed",), ), ) self.norm_added_k = nnx.RMSNorm( - num_features=self.inner_dim, rngs=rngs, epsilon=eps, dtype=dtype, param_dtype=weights_dtype, + num_features=self.inner_dim, + rngs=rngs, + epsilon=eps, + dtype=dtype, + param_dtype=weights_dtype, scale_init=nnx.with_partitioning( nnx.initializers.ones, ("norm",), @@ -1058,8 +1109,9 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: - hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) - encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) dtype = hidden_states.dtype is_self_attention = encoder_hidden_states is None if encoder_hidden_states is None: @@ -1120,10 +1172,10 @@ def __call__( # Use the passed encoder_attention_mask (created in embeddings_flax.py) if using Flash Attention # It contains the image mask: [1]*257 + [0]*127 for 257 real image tokens padded to 384 if encoder_attention_mask is not None: - encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len] + encoder_attention_mask_img = encoder_attention_mask[:, :padded_img_len] else: - # Fallback: no mask means treat all as valid (for dot product attention) - encoder_attention_mask_img = None + # Fallback: no mask means treat all as valid (for dot product attention) + encoder_attention_mask_img = None else: # If no image_seq_len is specified, treat all as text encoder_hidden_states_img = None @@ -1134,7 +1186,7 @@ def __call__( with self.conditional_named_scope("attn_q_norm"): query_proj_text = self.norm_q(query_proj_raw) else: - query_proj_text = query_proj_raw + query_proj_text = query_proj_raw # Text K/V with self.conditional_named_scope("proj_key"): @@ -1163,13 +1215,14 @@ def __call__( value_proj_img = checkpoint_name(value_proj_img, "value_proj_img") query_proj_img = checkpoint_name(query_proj_img, "query_proj_img") - # Attention - tensors are (B, S, D) with self.conditional_named_scope("cross_attn_text_apply"): attn_output_text = self.attention_op.apply_attention(query_proj_text, key_proj_text, value_proj_text) with self.conditional_named_scope("cross_attn_img_apply"): # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens - attn_output_img = self.attention_op.apply_attention(query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img) + attn_output_img = self.attention_op.apply_attention( + query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img + ) attn_output = attn_output_text + attn_output_img else: diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 21c67e10..41afa3b4 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -249,10 +249,32 @@ def get_1d_rotary_pos_embed( out = jnp.exp(1j * freqs) return out + class NNXWanImageEmbedding(nnx.Module): - def __init__(self, rngs: nnx.Rngs, in_features: int, out_features: int, dtype: jnp.dtype, weights_dtype: jnp.dtype, precision: jax.lax.Precision, pos_embed_seq_len=None, alignment: int = 128, flash_min_seq_length: int = 4096): + + def __init__( + self, + rngs: nnx.Rngs, + in_features: int, + out_features: int, + dtype: jnp.dtype, + weights_dtype: jnp.dtype, + precision: jax.lax.Precision, + pos_embed_seq_len=None, + alignment: int = 128, + flash_min_seq_length: int = 4096, + ): self.norm1 = FP32LayerNorm(rngs=rngs, dim=in_features, elementwise_affine=True, eps=1e-6) - self.ff = NNXSimpleFeedForward(rngs=rngs, dim=in_features, dim_out=out_features, mult=1, activation_fn="gelu", dtype=dtype, weights_dtype=weights_dtype, precision=precision) + self.ff = NNXSimpleFeedForward( + rngs=rngs, + dim=in_features, + dim_out=out_features, + mult=1, + activation_fn="gelu", + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + ) self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6) self.alignment = alignment self.flash_min_seq_length = flash_min_seq_length @@ -271,14 +293,14 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j # Apply pos_embed to the original sequence length hidden_states = hidden_states.at[:, :add_len, :].add(self.pos_embed.value[:, :add_len, :]) if current_seq_len > pe_len: - print(f"[WARN] Input seq_len {current_seq_len} > pos_embed len {pe_len}") + print(f"[WARN] Input seq_len {current_seq_len} > pos_embed len {pe_len}") hidden_states = self.norm1(hidden_states) hidden_states = self.ff(hidden_states) hidden_states = self.norm2(hidden_states) # hidden_states shape: (B, current_seq_len, out_features) B, current_seq_len, D_out = hidden_states.shape - use_flash_attn = current_seq_len>=self.flash_min_seq_length + use_flash_attn = current_seq_len >= self.flash_min_seq_length if use_flash_attn: # --- Dynamic Padding to nearest multiple of self.alignment --- @@ -291,13 +313,13 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> tuple[jax.Array, j attention_mask = jnp.ones((B, current_seq_len), dtype=jnp.int32) if current_seq_len < target_seq_len: - padding_size = target_seq_len - current_seq_len - padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype) - hidden_states = jnp.concatenate([hidden_states, padding], axis=1) + padding_size = target_seq_len - current_seq_len + padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype) + hidden_states = jnp.concatenate([hidden_states, padding], axis=1) - # Extend mask with zeros for padded positions - padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32) - attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1) + # Extend mask with zeros for padded positions + padding_mask = jnp.zeros((B, padding_size), dtype=jnp.int32) + attention_mask = jnp.concatenate([attention_mask, padding_mask], axis=1) if not use_flash_attn: attention_mask = None return hidden_states, attention_mask diff --git a/src/maxdiffusion/models/flux/__init__.py b/src/maxdiffusion/models/flux/__init__.py index 84dd0f15..217c0ac8 100644 --- a/src/maxdiffusion/models/flux/__init__.py +++ b/src/maxdiffusion/models/flux/__init__.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from .transformers.transformer_flux_flax import FluxTransformer2DModel diff --git a/src/maxdiffusion/models/flux/transformers/__init__.py b/src/maxdiffusion/models/flux/transformers/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/flux/transformers/__init__.py +++ b/src/maxdiffusion/models/flux/transformers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py index 7f63da67..814e21ea 100644 --- a/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py +++ b/src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Tuple import jax @@ -180,7 +180,6 @@ class FluxTransformerBlock(nn.Module): attention_kernel: str = "dot_product" def setup(self): - self.img_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) self.txt_norm1 = AdaLayerNormZero(self.dim, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision) @@ -203,29 +202,27 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.img_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) + self.img_mlp = nn.Sequential([ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ]) self.txt_norm2 = nn.LayerNorm( use_bias=False, @@ -234,29 +231,27 @@ def setup(self): dtype=self.dtype, param_dtype=self.weights_dtype, ) - self.txt_mlp = nn.Sequential( - [ - nn.Dense( - int(self.dim * self.mlp_ratio), - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - nn.gelu, - nn.Dense( - self.dim, - use_bias=True, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), - dtype=self.dtype, - param_dtype=self.weights_dtype, - precision=self.precision, - ), - ] - ) + self.txt_mlp = nn.Sequential([ + nn.Dense( + int(self.dim * self.mlp_ratio), + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + nn.gelu, + nn.Dense( + self.dim, + use_bias=True, + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, (None,)), + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ), + ]) # let chunk size default to None self._chunk_size = None diff --git a/src/maxdiffusion/models/flux/util.py b/src/maxdiffusion/models/flux/util.py index 8f7d0bf5..a4f665c6 100644 --- a/src/maxdiffusion/models/flux/util.py +++ b/src/maxdiffusion/models/flux/util.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # copied from https://github.com/ml-gde/jflux/blob/main/jflux/util.py import os diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py index 9162fbcb..18e5c7e6 100644 --- a/src/maxdiffusion/models/gradient_checkpoint.py +++ b/src/maxdiffusion/models/gradient_checkpoint.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from enum import Enum, auto diff --git a/src/maxdiffusion/models/lora.py b/src/maxdiffusion/models/lora.py index 82d32e80..88a2b92a 100644 --- a/src/maxdiffusion/models/lora.py +++ b/src/maxdiffusion/models/lora.py @@ -1,17 +1,17 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Union, Tuple, Optional diff --git a/src/maxdiffusion/models/lora_nnx.py b/src/maxdiffusion/models/lora_nnx.py new file mode 100644 index 00000000..97456630 --- /dev/null +++ b/src/maxdiffusion/models/lora_nnx.py @@ -0,0 +1,505 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import re +import torch +import jax +from jax import dlpack +import jax.numpy as jnp +from flax import nnx +from .. import max_logging +import numpy as np + +# ----------------------------------------------------------------------------- +# JIT Helpers (The Fix for Sharding & Device-Side Computation) +# ----------------------------------------------------------------------------- + + +@jax.jit +def _compute_and_add_single_jit(kernel, bias, down, up, scale, w_diff, b_diff): + """ + Applies LoRA + Weight Diff + Bias Diff on device. + """ + # 1. Apply LoRA (if valid) + if down is not None and up is not None: + # down: (Rank, In), up: (Out, Rank) -> Result: (In, Out) + # Note: We reshape to kernel shape to handle 1x1 convs + delta = (down.T @ up.T).reshape(kernel.shape) + kernel = kernel + (delta * scale).astype(kernel.dtype) + + # 2. Apply Full Weight Diff (if valid) + if w_diff is not None: + kernel = kernel + w_diff.astype(kernel.dtype) + + # 3. Apply Bias Diff (if valid and bias exists) + if bias is not None and b_diff is not None: + bias = bias + b_diff.astype(bias.dtype) + + return kernel, bias + + +@jax.jit +def _compute_and_add_scanned_jit(kernel, downs, ups, alphas, global_scale, w_diffs=None, b_diffs=None, bias=None): + """ + Applies scanned LoRA + Diffs. + """ + # 1. Apply LoRA + if downs is not None and ups is not None: + rank = downs.shape[1] + scales = global_scale * alphas / rank + # Batch Matmul: (L, In, Out) + delta = jnp.matmul(jnp.swapaxes(downs, 1, 2), jnp.swapaxes(ups, 1, 2)) + delta = (delta * scales).astype(kernel.dtype) + kernel = kernel + delta.reshape(kernel.shape) + + # 2. Apply Scanned Weight Diffs (L, ...) + if w_diffs is not None: + kernel = kernel + w_diffs.astype(kernel.dtype) + + # 3. Apply Scanned Bias Diffs (L, ...) + # Note: Scanned bias is usually shape (L, Out) + if bias is not None and b_diffs is not None: + bias = bias + b_diffs.astype(bias.dtype) + + return kernel, bias + + +# ----------------------------------------------------------------------------- + + +def _to_jax_array(v, dtype): + jax_dtype = jnp.dtype(dtype) + if isinstance(v, torch.Tensor): + return dlpack.from_dlpack(v).astype(jax_dtype) + return jnp.array(v, dtype=jax_dtype) + + +def parse_lora_dict(state_dict, dtype): + """ + Helper to parse state_dict into structured params including diffs. + Supports keys in original WAN LoRA format, like: + '...lora_A.weight', '...lora_up.weight', '...alpha', + as well as '.diff' and '.diff_b' for weight and bias fine-tuning. + Supports ComfyUI and AI toolkit lora formats. + """ + lora_params = {} + for k, v in state_dict.items(): + # Alpha + if k.endswith(".alpha"): + key_base = k[: -len(".alpha")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["alpha"] = _to_jax_array(v, dtype=dtype) + continue + + # Bias Diff (e.g., "layer.diff_b") + if k.endswith(".diff_b"): + key_base = k[: -len(".diff_b")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff_b"] = _to_jax_array(v, dtype=dtype) + continue + + # Weight Diff (e.g., "layer.diff") + if k.endswith(".diff"): + key_base = k[: -len(".diff")] + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base]["diff"] = _to_jax_array(v, dtype=dtype) + continue + + # Standard LoRA + m = re.match(r"^(.*?)_lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.lora\.(down|up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.(lora_down|lora_up)\.weight$", k) + if not m: + m = re.match(r"^(.*?)\.(lora_A|lora_B)\.weight$", k) + + if m: + key_base, weight_type = m.group(1), m.group(2).replace("lora_", "") + if weight_type == "A": + weight_type = "down" + elif weight_type == "B": + weight_type = "up" + if key_base not in lora_params: + lora_params[key_base] = {} + lora_params[key_base][weight_type] = _to_jax_array(v, dtype=dtype) + else: + # Fallback for exact matches of diffs if regex failed above + max_logging.log(f"Key {k} did not match any LoRA pattern.") + pass + + return lora_params + + +def _merge_lora_layer(module, weights, scale): + """Merges LoRA weights into a single non-scanned layer.""" + is_conv_kxk_locon = False + if isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) and "down" in weights and "up" in weights: + is_conv_kxk_locon = True + + updated = False + # Handle Embeddings + if isinstance(module, nnx.Embed): + if "diff" in weights and hasattr(module, "embedding"): + module.embedding.value += np.array(weights["diff"]).reshape(module.embedding.shape).astype(module.embedding.dtype) + updated = True + # Handle Norms + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + scale_diff = weights.get("diff", None) + bias_diff = weights.get("diff_b", None) + if scale_diff is not None and hasattr(module, "scale") and module.scale is not None: + module.scale.value += np.array(scale_diff).reshape(module.scale.shape).astype(module.scale.dtype) + updated = True + if bias_diff is not None and isinstance(module, nnx.LayerNorm) and hasattr(module, "bias") and module.bias is not None: + module.bias.value += np.array(bias_diff).reshape(module.bias.shape).astype(module.bias.dtype) + updated = True + elif isinstance(module, nnx.Param): + if "diff" in weights: + module.value += np.array(weights["diff"]).reshape(module.shape).astype(module.dtype) + updated = True + elif isinstance(module, (nnx.Linear, nnx.Conv)): + # Prepare LoRA terms + down_w, up_w, current_scale = None, None, None + if "down" in weights and "up" in weights and not is_conv_kxk_locon: + down_w, up_w = weights["down"], weights["up"] + down_w, up_w = np.array(down_w), np.array(up_w) # CPU convert + + # Squeeze dimensions if needed (Conv 1x1 or Linear) + if isinstance(module, nnx.Conv) and module.kernel_size == (1, 1): + down_w, up_w = np.squeeze(down_w), np.squeeze(up_w) + + rank = down_w.shape[0] if down_w.ndim > 0 else 0 + alpha = float(weights.get("alpha", rank)) + current_scale = scale * alpha / rank + + # Prepare Diff terms + w_diff = weights.get("diff", None) + b_diff = weights.get("diff_b", None) + + if w_diff is not None: + w_diff = np.array(w_diff) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. + if isinstance(module, nnx.Conv): + if w_diff.ndim == 5: + w_diff = w_diff.transpose((2, 3, 4, 1, 0)) + elif w_diff.ndim == 4: + w_diff = w_diff.transpose((2, 3, 1, 0)) + elif isinstance(module, nnx.Linear) and w_diff.ndim == 2: + w_diff = w_diff.transpose((1, 0)) + if b_diff is not None: + b_diff = np.array(b_diff) + + # If LoCON, compute delta and add to w_diff + if is_conv_kxk_locon: + dw, uw = np.array(weights["down"]), np.array(weights["up"]) + rank, in_c, *k_dims = dw.shape + out_c = uw.shape[0] + alpha = float(weights.get("alpha", rank)) + + delta_pt = (uw.reshape(out_c, rank) @ dw.reshape(rank, -1)).reshape(out_c, in_c, *k_dims) + + # Transpose to flax + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + + lora_delta = delta_fx * (scale * alpha / rank) + if w_diff is None: + w_diff = lora_delta.astype(np.float32) + else: + w_diff += lora_delta.astype(w_diff.dtype) + + # Check for Bias existence + bias_val = module.bias.value if module.bias is not None else None + + # --- EXECUTE JIT UPDATE --- + if down_w is not None or w_diff is not None or b_diff is not None: + new_kernel, new_bias = _compute_and_add_single_jit( + module.kernel.value, bias_val, down_w, up_w, current_scale, w_diff, b_diff + ) + + module.kernel.value = new_kernel + if new_bias is not None: + module.bias.value = new_bias + + updated = True + else: + max_logging.log("Matched key but found no actionable weights.") + return updated + + +def merge_lora(model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None, dtype: str = "float32"): + """ + Merges weights for non-scanned layers (Embeddings, singular Dense, etc). + Now supports diff and diff_b. + """ + lora_params = parse_lora_dict(state_dict, dtype=dtype) + max_logging.log(f"Parsed {len(lora_params)} unique module keys.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed, nnx.Param)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key = translate_fn(nnx_path_str) if translate_fn else None + + if lora_key and lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if _merge_lora_layer(module, weights, scale): + assigned_count += 1 + + max_logging.log(f"Merged weights into {assigned_count} layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}" + ) + + +def merge_lora_for_scanned( + model: nnx.Module, state_dict: dict, rank: int, scale: float, translate_fn=None, dtype: str = "float32" +): + """ + Device-Side Optimized Merge for Scanned Layers. + Now supports diff and diff_b. + """ + lora_params = parse_lora_dict(state_dict, dtype=dtype) + max_logging.log(f"Parsed {len(lora_params)} keys for scanned merge.") + matched_keys = set() + + assigned_count = 0 + for path, module in nnx.iter_graph(model): + if not isinstance(module, (nnx.Linear, nnx.Conv, nnx.LayerNorm, nnx.RMSNorm, nnx.Embed, nnx.Param)): + continue + + nnx_path_str = ".".join(map(str, path)) + lora_key_template = translate_fn(nnx_path_str) if translate_fn else None + + if not lora_key_template: + continue + + # Determine if layer is scanned based on parameter dimensions + is_scanned = False + if isinstance(module, nnx.Embed) and hasattr(module, "embedding"): + is_scanned = module.embedding.ndim > 2 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)) and hasattr(module, "scale") and module.scale is not None: + is_scanned = module.scale.ndim > 1 + elif isinstance(module, nnx.Linear): + is_scanned = module.kernel.ndim == 3 + elif isinstance(module, nnx.Conv): + is_scanned = module.kernel.ndim == 5 + elif isinstance(module, nnx.Param): + # Use template format to disambiguate: if template has {}, then it is scanned. + is_scanned = "{}" in lora_key_template + + if not is_scanned: + lora_key = lora_key_template + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if _merge_lora_layer(module, weights, scale): + assigned_count += 1 + continue + + # If we reach here, layer is SCANNED + if isinstance(module, nnx.Embed): + num_layers = module.embedding.shape[0] + embed_diffs_to_add = np.zeros_like(module.embedding.value) + updated = False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + if "diff" in lora_params[lora_key]: + embed_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.embedding.shape[1:]) + updated = True + if updated: + module.embedding.value += embed_diffs_to_add.astype(module.embedding.dtype) + assigned_count += 1 + elif isinstance(module, (nnx.LayerNorm, nnx.RMSNorm)): + num_layers = module.scale.shape[0] + scale_diffs_to_add = np.zeros_like(module.scale.value) + bias_diffs_to_add = ( + np.zeros_like(module.bias.value) + if isinstance(module, nnx.LayerNorm) and hasattr(module, "bias") and module.bias is not None + else None + ) + updated_scale, updated_bias = False, False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + weights = lora_params[lora_key] + if "diff" in weights: + scale_diffs_to_add[i] = np.array(weights["diff"]).reshape(module.scale.shape[1:]) + updated_scale = True + if "diff_b" in weights and bias_diffs_to_add is not None: + bias_diffs_to_add[i] = np.array(weights["diff_b"]).reshape(module.bias.shape[1:]) + updated_bias = True + if updated_scale: + module.scale.value += scale_diffs_to_add.astype(module.scale.dtype) + if updated_bias and bias_diffs_to_add is not None: + module.bias.value += bias_diffs_to_add.astype(module.bias.dtype) + if updated_scale or updated_bias: + assigned_count += 1 + elif isinstance(module, nnx.Param): + num_layers = module.shape[0] + param_diffs_to_add = np.zeros_like(module.value) + updated = False + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + if "diff" in lora_params[lora_key]: + param_diffs_to_add[i] = np.array(lora_params[lora_key]["diff"]).reshape(module.shape[1:]) + updated = True + if updated: + module.value += param_diffs_to_add.astype(module.dtype) + assigned_count += 1 + elif isinstance(module, (nnx.Linear, nnx.Conv)): + is_linear = isinstance(module, nnx.Linear) + is_conv = isinstance(module, nnx.Conv) + is_conv_kxk = isinstance(module, nnx.Conv) and module.kernel_size != (1, 1) + if is_linear: + num_layers, in_feat, out_feat = module.kernel.shape + else: # Conv + num_layers = module.kernel.shape[0] + in_feat, out_feat = module.kernel.shape[3], module.kernel.shape[4] + + # 1. Scan for Rank (Fallback use rank in config file) + found_rank = rank + for i in range(num_layers): + k = lora_key_template.format(i) + if k in lora_params and "down" in lora_params[k]: + found_rank = lora_params[k]["down"].shape[0] + break + + # 2. Pre-allocate Buffers (CPU) + # LoRA Buffers + stack_down = np.zeros((num_layers, found_rank, in_feat), dtype=np.float32) + stack_up = np.zeros((num_layers, out_feat, found_rank), dtype=np.float32) + stack_alpha = np.zeros((num_layers, 1, 1), dtype=np.float32) + + # Diff Buffers + # Initialize as None, allocate only if found to save memory + stack_w_diff = None + stack_b_diff = None + + has_lora = False + has_diff = False + + for i in range(num_layers): + lora_key = lora_key_template.format(i) + if lora_key in lora_params: + matched_keys.add(lora_key) + w = lora_params[lora_key] + + # --- Fill LoRA --- + if "down" in w: + d, u = np.array(w["down"]), np.array(w["up"]) + alpha = float(w.get("alpha", d.shape[0])) + rank_ = d.shape[0] + + if is_conv_kxk: + # For LoCON kxk, compute delta and merge into stack_w_diff + rank_, in_c, *k_dims = d.shape + out_c = u.shape[0] + delta_pt = (u.reshape(out_c, rank_) @ d.reshape(rank_, -1)).reshape(out_c, in_c, *k_dims) + if delta_pt.ndim == 5: + delta_fx = delta_pt.transpose((2, 3, 4, 1, 0)) + else: + delta_fx = delta_pt.transpose((2, 3, 1, 0)) + + lora_delta = delta_fx * (scale * alpha / rank_) + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + stack_w_diff[i] += lora_delta.reshape(stack_w_diff[i].shape).astype(stack_w_diff.dtype) + has_diff = True # Mark as having diff because we merged LoRA into w_diff + else: + # For Linear or 1x1 Conv, prepare for JIT + if d.ndim > 2: + d = np.squeeze(d) + if u.ndim > 2: + u = np.squeeze(u) + stack_down[i] = d + stack_up[i] = u + stack_alpha[i] = alpha + has_lora = True + + # --- Fill Weight Diff --- + if "diff" in w: + if stack_w_diff is None: + stack_w_diff = np.zeros(module.kernel.shape, dtype=np.float32) + wd = np.array(w["diff"]) + # Transpose weights from PyTorch OIHW/OIDHW to Flax HWIO/DHWIO if needed. + if is_conv: + if wd.ndim == 5: + wd = wd.transpose((2, 3, 4, 1, 0)) + elif wd.ndim == 4: + wd = wd.transpose((2, 3, 1, 0)) + elif is_linear and wd.ndim == 2: + wd = wd.transpose((1, 0)) + + stack_w_diff[i] += wd.reshape(stack_w_diff[i].shape) + has_diff = True + + # --- Fill Bias Diff --- + if "diff_b" in w: + if stack_b_diff is None: + # Bias shape: Linear (L, Out), Conv (L, Out) usually + stack_b_diff = np.zeros((num_layers, out_feat), dtype=np.float32) + bd = np.array(w["diff_b"]) + stack_b_diff[i] = bd.flatten() + has_diff = True + + if has_lora or has_diff: + bias_val = module.bias.value if module.bias is not None else None + + # Call JIT + new_k, new_b = _compute_and_add_scanned_jit( + module.kernel.value, + stack_down if has_lora else None, + stack_up if has_lora else None, + stack_alpha if has_lora else None, + scale, + stack_w_diff, + stack_b_diff, + bias_val, + ) + + module.kernel.value = new_k + if new_b is not None: + module.bias.value = new_b + + assigned_count += 1 + else: + # Should not happen based on is_scanned logic + max_logging.log(f"Module {nnx_path_str} has scanned weights but is not Linear, Conv, Embed, or Norm type.") + continue + + max_logging.log(f"Merged weights into {assigned_count} scanned layers.") + unmatched_keys = set(lora_params.keys()) - matched_keys + if unmatched_keys: + max_logging.log( + f"{len(unmatched_keys)} key(s) in LoRA dictionary were not applied to any layer in the model: {unmatched_keys}" + ) diff --git a/src/maxdiffusion/models/ltx_video/__init__.py b/src/maxdiffusion/models/ltx_video/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/ltx_video/__init__.py +++ b/src/maxdiffusion/models/ltx_video/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/ltx_video/transformers/__init__.py b/src/maxdiffusion/models/ltx_video/transformers/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/__init__.py +++ b/src/maxdiffusion/models/ltx_video/transformers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/ltx_video/transformers/adaln.py b/src/maxdiffusion/models/ltx_video/transformers/adaln.py index 1078f084..e392e4f6 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/adaln.py +++ b/src/maxdiffusion/models/ltx_video/transformers/adaln.py @@ -126,7 +126,6 @@ def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray: class AlphaCombinedTimestepSizeEmbeddings(nn.Module): - embedding_dim: int size_emb_dim: int dtype: jnp.dtype = jnp.float32 diff --git a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py index 8b12b1d8..67902936 100644 --- a/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py +++ b/src/maxdiffusion/models/ltx_video/transformers/transformer3d.py @@ -227,7 +227,6 @@ def __call__( encoder_hidden_states = self.caption_projection(encoder_hidden_states) if self.num_layers > 0: - hidden_states = self.transformer_blocks( hidden_states, freqs_cis, diff --git a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py index 6241804b..93b16851 100644 --- a/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py +++ b/src/maxdiffusion/models/ltx_video/transformers_pytorch/attention.py @@ -1,18 +1,19 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import inspect from importlib import import_module from typing import Any, Dict, Optional, Tuple diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 685b0c0b..1239ddbc 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch - Flax general utilities.""" +"""PyTorch - Flax general utilities.""" import re import torch @@ -348,10 +348,14 @@ def convert_lora_pytorch_state_dict_to_flax(pt_state_dict, params, network_alpha text_encoder_2_params = flatten_dict(unfreeze(params["text_encoder_2"])) else: text_encoder_2_params = None - (unet_state_dict, text_encoder_state_dict, text_encoder_2_state_dict, rank, network_alphas) = ( - create_flax_params_from_pytorch_state( - pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True - ) + ( + unet_state_dict, + text_encoder_state_dict, + text_encoder_2_state_dict, + rank, + network_alphas, + ) = create_flax_params_from_pytorch_state( + pt_state_dict, unet_params, text_encoder_params, text_encoder_2_params, network_alphas, adapter_name, is_lora=True ) params["unet"] = unflatten_dict(unet_state_dict) params["text_encoder"] = unflatten_dict(text_encoder_state_dict) diff --git a/src/maxdiffusion/models/normalization_flax.py b/src/maxdiffusion/models/normalization_flax.py index 2ba658d4..24f423f1 100644 --- a/src/maxdiffusion/models/normalization_flax.py +++ b/src/maxdiffusion/models/normalization_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax import jax.numpy as jnp diff --git a/src/maxdiffusion/models/vae_flax.py b/src/maxdiffusion/models/vae_flax.py index dc9b0063..042ec275 100644 --- a/src/maxdiffusion/models/vae_flax.py +++ b/src/maxdiffusion/models/vae_flax.py @@ -21,6 +21,7 @@ import flax import flax.linen as nn import jax +from jax import tree_util import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict @@ -74,7 +75,6 @@ class FlaxUpsample2D(nn.Module): weights_dtype: jnp.dtype = jnp.float32 def setup(self): - self.conv = nn.Conv( self.in_channels, kernel_size=(3, 3), @@ -931,3 +931,30 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r return (sample,) return FlaxDecoderOutput(sample=sample) + + +class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution): + pass + + +def _wan_diag_gauss_dist_flatten(dist): + return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,) + + +def _wan_diag_gauss_dist_unflatten(aux, children): + mean, logvar, std, var = children + deterministic = aux[0] + obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution) + obj.mean = mean + obj.logvar = logvar + obj.std = std + obj.var = var + obj.deterministic = deterministic + return obj + + +tree_util.register_pytree_node( + WanDiagonalGaussianDistribution, + _wan_diag_gauss_dist_flatten, + _wan_diag_gauss_dist_unflatten, +) diff --git a/src/maxdiffusion/models/wan/__init__.py b/src/maxdiffusion/models/wan/__init__.py index 7e4185f3..4a62083b 100644 --- a/src/maxdiffusion/models/wan/__init__.py +++ b/src/maxdiffusion/models/wan/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 1da2d18f..0328f6ac 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Tuple, List, Sequence, Union, Optional @@ -19,16 +19,32 @@ import flax import jax import jax.numpy as jnp +from jax import tree_util from flax import nnx from ...configuration_utils import ConfigMixin from ..modeling_flax_utils import FlaxModelMixin, get_activation from ... import common_types -from ..vae_flax import (FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution, FlaxDecoderOutput) +from ..vae_flax import ( + FlaxAutoencoderKLOutput, + FlaxDiagonalGaussianDistribution, + FlaxDecoderOutput, + WanDiagonalGaussianDistribution, +) BlockSizes = common_types.BlockSizes CACHE_T = 2 -flax.config.update('flax_always_shard_variable', False) +try: + flax.config.update("flax_always_shard_variable", False) +except LookupError: + pass + + +def _update_cache(cache, idx, value): + if cache is None: + return None + return cache[:idx] + (value,) + cache[idx + 1 :] + # Helper to ensure kernel_size, stride, padding are tuples of 3 integers def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]: @@ -41,6 +57,15 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}") +class RepSentinel: + + def __eq__(self, other): + return isinstance(other, RepSentinel) + + +tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel()) + + class WanCausalConv3d(nnx.Module): def __init__( @@ -72,10 +97,11 @@ def __init__( # Store the amount of padding needed *before* the depth dimension for caching logic self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0] + self.mesh = mesh # Set sharding dynamically based on out_channels. - num_fsdp_axis_devices = mesh.device_ids.shape[1] + num_context_axis_devices = mesh.shape["context"] kernel_sharding = (None, None, None, None, None) - if out_channels % num_fsdp_axis_devices == 0: + if out_channels % num_context_axis_devices == 0: kernel_sharding = (None, None, None, None, "conv_out") self.conv = nnx.Conv( @@ -117,6 +143,7 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None, idx=-1) -> x_padded = jnp.pad(x, padding_to_apply, mode="constant", constant_values=0.0) else: x_padded = x + out = self.conv(x_padded) return out @@ -308,30 +335,30 @@ def __init__( else: self.resample = Identity() - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): # Input x: (N, D, H, W, C), assume C = self.dim b, t, h, w, c = x.shape assert c == self.dim if self.mode == "upsample3d": if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx if feat_cache[idx] is None: - feat_cache[idx] = "Rep" - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, RepSentinel()) + feat_idx += 1 else: cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and not isinstance(feat_cache[idx], RepSentinel): # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) - if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], RepSentinel): cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1) - if feat_cache[idx] == "Rep": + if isinstance(feat_cache[idx], RepSentinel): x = self.time_conv(x) else: x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 x = x.reshape(b, t, h, w, 2, c) x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1) x = x.reshape(b, t * 2, h, w, c) @@ -343,17 +370,17 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array: if self.mode == "downsample3d": if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx if feat_cache[idx] is None: - feat_cache[idx] = jnp.copy(x) - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, jnp.copy(x)) + feat_idx += 1 else: cache_x = jnp.copy(x[:, -1:, :, :, :]) x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 - return x + return x, feat_cache, feat_idx class WanResidualBlock(nnx.Module): @@ -412,7 +439,7 @@ def __init__( else Identity() ) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): # Apply shortcut connection h = self.conv_shortcut(x) @@ -420,32 +447,31 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.nonlinearity(x) if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv1(x, feat_cache[idx], idx) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv1(x) x = self.norm2(x) x = self.nonlinearity(x) - idx = feat_idx[0] if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv2(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv2(x) x = x + h - return x + return x, feat_cache, feat_idx class WanAttentionBlock(nnx.Module): @@ -482,8 +508,7 @@ def __init__( precision=precision, ) - def __call__(self, x: jax.Array): - + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): identity = x batch_size, time, height, width, channels = x.shape @@ -506,7 +531,7 @@ def __call__(self, x: jax.Array): # Reshape back x = x.reshape(batch_size, time, height, width, channels) - return x + identity + return x + identity, feat_cache, feat_idx class WanMidBlock(nnx.Module): @@ -558,13 +583,13 @@ def __init__( self.attentions = nnx.data(attentions) self.resnets = nnx.data(resnets) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): - x = self.resnets[0](x, feat_cache, feat_idx) + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): + x, feat_cache, feat_idx = self.resnets[0](x, feat_cache, feat_idx) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - x = attn(x) - x = resnet(x, feat_cache, feat_idx) - return x + x, feat_cache, feat_idx = attn(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx) + return x, feat_cache, feat_idx class WanUpBlock(nnx.Module): @@ -619,19 +644,13 @@ def __init__( ) ] - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): for resnet in self.resnets: - if feat_cache is not None: - x = resnet(x, feat_cache, feat_idx) - else: - x = resnet(x) + x, feat_cache, feat_idx = resnet(x, feat_cache, feat_idx) if self.upsamplers is not None: - if feat_cache is not None: - x = self.upsamplers[0](x, feat_cache, feat_idx) - else: - x = self.upsamplers[0](x) - return x + x, feat_cache, feat_idx = self.upsamplers[0](x, feat_cache, feat_idx) + return x, feat_cache, feat_idx class WanEncoder3d(nnx.Module): @@ -740,40 +759,38 @@ def __init__( precision=precision, ) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + @nnx.jit(static_argnames="feat_idx") + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_in(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_in(x) for layer in self.down_blocks: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) + x, feat_cache, feat_idx = layer(x, feat_cache, feat_idx) - x = self.mid_block(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = self.mid_block(x, feat_cache, feat_idx) x = self.norm_out(x) x = self.nonlinearity(x) if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_out(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_out(x) - return x + return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32) class WanDecoder3d(nnx.Module): @@ -891,50 +908,47 @@ def __init__( precision=precision, ) - def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): + @nnx.jit(static_argnames="feat_idx") + def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0): if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_in(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_in(x) ## middle - x = self.mid_block(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = self.mid_block(x, feat_cache, feat_idx) ## upsamples for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx) + x, feat_cache, feat_idx = up_block(x, feat_cache, feat_idx) ## head x = self.norm_out(x) x = self.nonlinearity(x) if feat_cache is not None: - idx = feat_idx[0] + idx = feat_idx cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :]) if cache_x.shape[1] < 2 and feat_cache[idx] is not None: # cache last frame of the last two chunk cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1) x = self.conv_out(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 + feat_cache = _update_cache(feat_cache, idx, cache_x) + feat_idx += 1 else: x = self.conv_out(x) - return x + return x, feat_cache, jnp.array(feat_idx, dtype=jnp.int32) class AutoencoderKLWanCache: def __init__(self, module): self.module = module - self.clear_cache() - - def clear_cache(self): - """Resets cache dictionaries and indices""" def _count_conv3d(module): count = 0 @@ -945,12 +959,38 @@ def _count_conv3d(module): return count self._conv_num = _count_conv3d(self.module.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - # cache encode self._enc_conv_num = _count_conv3d(self.module.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num + self.init_cache() + + def init_cache(self): + """Resets cache dictionaries and indices""" + self._feat_map = (None,) * self._conv_num + # cache encode + self._enc_feat_map = (None,) * self._enc_conv_num + + +def _wan_cache_flatten(cache): + return (cache._feat_map, cache._enc_feat_map), (cache._conv_num, cache._enc_conv_num) + + +def _wan_cache_unflatten(aux, children): + conv_num, enc_conv_num = aux + feat_map, enc_feat_map = children + # Create a dummy object or one without module reference for JIT internal use + # We can't easily reconstruct 'module' but we don't need it for init_cache anymore + # if we store counts in aux. + # However, __init__ expects module. + # We will bypass __init__ for unflattening. + obj = AutoencoderKLWanCache.__new__(AutoencoderKLWanCache) + obj._conv_num = conv_num + obj._enc_conv_num = enc_conv_num + obj._feat_map = feat_map + obj._enc_feat_map = enc_feat_map + obj.module = None # module is not needed inside the trace for the cache logic now + return obj + + +tree_util.register_pytree_node(AutoencoderKLWanCache, _wan_cache_flatten, _wan_cache_unflatten) class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -1064,7 +1104,7 @@ def __init__( ) def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): - feat_cache.clear_cache() + feat_cache.init_cache() if x.shape[-1] != 3: # reshape channel last for JAX x = jnp.transpose(x, (0, 2, 3, 4, 1)) @@ -1072,21 +1112,27 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): t = x.shape[1] iter_ = 1 + (t - 1) // 4 + enc_feat_map = feat_cache._enc_feat_map + for i in range(iter_): - feat_cache._enc_conv_idx = [0] + enc_conv_idx = 0 if i == 0: - out = self.encoder(x[:, :1, :, :, :], feat_cache=feat_cache._enc_feat_map, feat_idx=feat_cache._enc_conv_idx) + out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx) else: - out_ = self.encoder( + out_, enc_feat_map, enc_conv_idx = self.encoder( x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], - feat_cache=feat_cache._enc_feat_map, - feat_idx=feat_cache._enc_conv_idx, + feat_cache=enc_feat_map, + feat_idx=enc_conv_idx, ) out = jnp.concatenate([out, out_], axis=1) + + # Update back to the wrapper object if needed, but for result we use local vars + feat_cache._enc_feat_map = enc_feat_map + enc = self.quant_conv(out) mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :] enc = jnp.concatenate([mu, logvar], axis=-1) - feat_cache.clear_cache() + feat_cache.init_cache() return enc def encode( @@ -1094,7 +1140,7 @@ def encode( ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]: """Encode video into latent distribution.""" h = self._encode(x, feat_cache) - posterior = FlaxDiagonalGaussianDistribution(h) + posterior = WanDiagonalGaussianDistribution(h) if not return_dict: return (posterior,) return FlaxAutoencoderKLOutput(latent_dist=posterior) @@ -1102,15 +1148,18 @@ def encode( def _decode( self, z: jax.Array, feat_cache: AutoencoderKLWanCache, return_dict: bool = True ) -> Union[FlaxDecoderOutput, jax.Array]: - feat_cache.clear_cache() + feat_cache.init_cache() iter_ = z.shape[1] x = self.post_quant_conv(z) + + dec_feat_map = feat_cache._feat_map + for i in range(iter_): - feat_cache._conv_idx = [0] + conv_idx = 0 if i == 0: - out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx) + out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) else: - out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=feat_cache._feat_map, feat_idx=feat_cache._conv_idx) + out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) # This is to bypass an issue where frame[1] should be frame[2] and vise versa. # Ideally shouldn't need to do this however, can't find where the frame is going out of sync. @@ -1128,8 +1177,11 @@ def _decode( fm3 = jnp.expand_dims(fm3, axis=axis) fm4 = jnp.expand_dims(fm4, axis=axis) out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1) + + feat_cache._feat_map = dec_feat_map + out = jnp.clip(out, min=-1.0, max=1.0) - feat_cache.clear_cache() + feat_cache.init_cache() if not return_dict: return (out,) diff --git a/src/maxdiffusion/models/wan/transformers/__init__.py b/src/maxdiffusion/models/wan/transformers/__init__.py index 9ff757fc..4a62083b 100644 --- a/src/maxdiffusion/models/wan/transformers/__init__.py +++ b/src/maxdiffusion/models/wan/transformers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a18b127c..16d70764 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from typing import Tuple, Optional, Dict, Union, Any @@ -19,7 +19,6 @@ import math import jax import jax.numpy as jnp -from jax.sharding import PartitionSpec from jax.ad_checkpoint import checkpoint_name from flax import nnx import flax.linen as nn @@ -104,7 +103,7 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, - flash_min_seq_length: int = 4096 + flash_min_seq_length: int = 4096, ): self.timesteps_proj = NNXFlaxTimesteps(dim=time_freq_dim, flip_sin_to_cos=True, freq_shift=0) self.time_embedder = NNXTimestepEmbedding( @@ -149,7 +148,7 @@ def __init__( dtype=dtype, weights_dtype=weights_dtype, precision=precision, - flash_min_seq_length=flash_min_seq_length + flash_min_seq_length=flash_min_seq_length, ) def __call__( @@ -261,11 +260,11 @@ def conditional_named_scope(self, name: str): return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: - hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) - hidden_states = checkpoint_name(hidden_states, "ffn_activation") - hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) - with jax.named_scope("proj_out"): - return self.proj_out(hidden_states) # output is (4, 75600, 5120) + hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) + hidden_states = checkpoint_name(hidden_states, "ffn_activation") + hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) + with jax.named_scope("proj_out"): + return self.proj_out(hidden_states) # output is (4, 75600, 5120) class WanTransformerBlock(nnx.Module): @@ -292,7 +291,6 @@ def __init__( mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, ): - self.enable_jax_named_scopes = enable_jax_named_scopes # 1. Self-attention @@ -381,9 +379,11 @@ def __call__( shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) - hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads")) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) hidden_states = checkpoint_name(hidden_states, "hidden_states") - encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) + axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv")) + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names) # 1. Self-attention with self.conditional_named_scope("self_attn"): @@ -412,7 +412,7 @@ def __call__( encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs, - encoder_attention_mask = encoder_attention_mask + encoder_attention_mask=encoder_attention_mask, ) with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output @@ -504,7 +504,7 @@ def __init__( text_embed_dim=text_dim, image_embed_dim=image_dim, pos_embed_seq_len=pos_embed_seq_len, - flash_min_seq_length=flash_min_seq_length + flash_min_seq_length=flash_min_seq_length, ) # 3. Transformer blocks @@ -539,7 +539,7 @@ def init_block(rngs): if scan_layers: self.blocks = init_block(rngs) else: - blocks = nnx.List([]) + blocks = [] for _ in range(num_layers): block = WanTransformerBlock( rngs=rngs, @@ -561,7 +561,7 @@ def init_block(rngs): enable_jax_named_scopes=enable_jax_named_scopes, ) blocks.append(block) - self.blocks = blocks + self.blocks = nnx.data(blocks) self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( @@ -583,7 +583,7 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() - @jax.named_scope('WanModel') + @jax.named_scope("WanModel") def __call__( self, hidden_states: jax.Array, @@ -609,24 +609,37 @@ def __call__( hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) with self.conditional_named_scope("condition_embedder"): - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: - encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) - if encoder_attention_mask is not None: - text_mask = jnp.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), dtype=jnp.int32) - encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) - encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + if encoder_attention_mask is not None: + text_mask = jnp.ones( + (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), + dtype=jnp.int32, + ) + encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) + encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) if self.scan_layers: def scan_fn(carry, block): hidden_states_carry, rngs_carry = carry hidden_states = block( - hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry, encoder_attention_mask + hidden_states_carry, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs_carry, + encoder_attention_mask, ) new_carry = (hidden_states, rngs_carry) return new_carry, None @@ -647,7 +660,15 @@ def scan_fn(carry, block): for block in self.blocks: def layer_forward(hidden_states): - return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs, encoder_attention_mask=encoder_attention_mask) + return block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs, + encoder_attention_mask=encoder_attention_mask, + ) rematted_layer_forward = self.gradient_checkpoint.apply( layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py index 1e1f7ae5..fc3e67e3 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py @@ -104,15 +104,11 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), ("embed", None) - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) # 2. Self-attention - self.norm1 = FP32LayerNorm( - rngs=rngs, dim=dim, eps=eps, elementwise_affine=False - ) + self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) self.attn1 = FlaxWanAttention( rngs=rngs, query_dim=dim, @@ -150,9 +146,7 @@ def __init__( residual_checkpoint_name="cross_attn", ) assert cross_attn_norm is True, "cross_attn_norm must be True" - self.norm2 = FP32LayerNorm( - rngs=rngs, dim=dim, eps=eps, elementwise_affine=True - ) + self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) # 4. Feed-forward self.ffn = WanFeedForward( @@ -166,9 +160,7 @@ def __init__( dropout=dropout, ) - self.norm3 = FP32LayerNorm( - rngs=rngs, dim=dim, eps=eps, elementwise_affine=False - ) + self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) # 5. Output projection self.proj_out = nnx.data([None]) @@ -180,9 +172,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), ("embed", None) - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) key = rngs.params() @@ -205,19 +195,15 @@ def __call__( control_hidden_states = self.proj_in(control_hidden_states) control_hidden_states = control_hidden_states + hidden_states - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - jnp.split( - (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 - ) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) control_hidden_states = jax.lax.with_sharding_constraint( control_hidden_states, PartitionSpec("data", "fsdp", "tensor"), ) - control_hidden_states = checkpoint_name( - control_hidden_states, "control_hidden_states" - ) + control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states") encoder_hidden_states = jax.lax.with_sharding_constraint( encoder_hidden_states, PartitionSpec("data", "fsdp", None), @@ -225,11 +211,9 @@ def __call__( # 1. Self-attention with jax.named_scope("attn1"): - norm_hidden_states = ( - self.norm1(control_hidden_states.astype(jnp.float32)) - * (1 + scale_msa) - + shift_msa - ).astype(control_hidden_states.dtype) + norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + control_hidden_states.dtype + ) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, @@ -237,15 +221,13 @@ def __call__( deterministic=deterministic, rngs=rngs, ) - control_hidden_states = ( - control_hidden_states.astype(jnp.float32) + attn_output * gate_msa - ).astype(control_hidden_states.dtype) + control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype( + control_hidden_states.dtype + ) # 2. Cross-attention with jax.named_scope("attn2"): - norm_hidden_states = self.norm2( - control_hidden_states.astype(jnp.float32) - ).astype(control_hidden_states.dtype) + norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype) attn_output = self.attn2( hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -256,17 +238,12 @@ def __call__( # 3. Feed-forward with jax.named_scope("ffn"): - norm_hidden_states = ( - self.norm3(control_hidden_states.astype(jnp.float32)) - * (1 + c_scale_msa) - + c_shift_msa - ).astype(control_hidden_states.dtype) - ff_output = self.ffn( - norm_hidden_states, deterministic=deterministic, rngs=rngs + norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + control_hidden_states.dtype ) + ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) control_hidden_states = ( - control_hidden_states.astype(jnp.float32) - + ff_output.astype(jnp.float32) * c_gate_msa + control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa ).astype(control_hidden_states.dtype) conditioning_states = None if self.apply_output_projection: @@ -327,9 +304,7 @@ def __init__( self.scan_layers = scan_layers # 1. Patch & position embedding - self.rope = WanRotaryPosEmbed( - attention_head_dim, patch_size, rope_max_seq_len - ) + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nnx.Conv( in_channels, inner_dim, @@ -356,9 +331,7 @@ def __init__( pos_embed_seq_len=pos_embed_seq_len, ) - self.gradient_checkpoint = GradientCheckpointType.from_str( - remat_policy - ) + self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) self.names_which_can_be_offloaded = names_which_can_be_offloaded self.names_which_can_be_saved = names_which_can_be_saved @@ -432,9 +405,7 @@ def __init__( ), ) - self.norm_out = FP32LayerNorm( - rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False - ) + self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( rngs=rngs, in_features=inner_dim, @@ -442,16 +413,12 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), ("embed", None) - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), ) key = rngs.params() self.scale_shift_table = nnx.Param( jax.random.normal(key, (1, 2, inner_dim)) / inner_dim**0.5, - kernel_init=nnx.with_partitioning( - nnx.initializers.xavier_uniform(), (None, None, "embed") - ), + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), ) @jax.named_scope("WanVACEModel") @@ -468,9 +435,7 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, ) -> jax.Array: - hidden_states = nn.with_logical_constraint( - hidden_states, ("batch", None, None, None, None) - ) + hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size post_patch_num_frames = num_frames // p_t @@ -478,9 +443,7 @@ def __call__( post_patch_width = width // p_w if control_hidden_states_scale is None: - control_hidden_states_scale = jnp.ones_like( - control_hidden_states, shape=(len(self.config.vace_layers),) - ) + control_hidden_states_scale = jnp.ones_like(control_hidden_states, shape=(len(self.config.vace_layers),)) if control_hidden_states_scale.shape[0] != len(self.config.vace_layers): raise ValueError( "Length of `control_hidden_states_scale`" @@ -489,9 +452,7 @@ def __call__( ) hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) - control_hidden_states = jnp.transpose( - control_hidden_states, (0, 2, 3, 4, 1) - ) + control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1)) rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) @@ -505,15 +466,17 @@ def __call__( hidden_states.shape[2] - control_hidden_states.shape[2], )) - control_hidden_states = jnp.concatenate( - [control_hidden_states, control_hidden_states_padding], axis=2 - ) + control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2) # Condition embedder is a FC layer. - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( - self.condition_embedder( # We will need to mask out the text embedding. - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + _, + ) = self.condition_embedder( # We will need to mask out the text embedding. + timestep, encoder_hidden_states, encoder_hidden_states_image ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) @@ -526,6 +489,7 @@ def __call__( # Prepare VACE hints control_hidden_states_list = nnx.List([]) for i, vace_block in enumerate(self.vace_blocks): + def layer_forward(hidden_states, control_hidden_states): return vace_block( hidden_states=hidden_states, @@ -543,12 +507,8 @@ def layer_forward(hidden_states, control_hidden_states): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - conditioning_states, control_hidden_states = rematted_layer_forward( - hidden_states, control_hidden_states - ) - control_hidden_states_list.append( - (conditioning_states, control_hidden_states_scale[i]) - ) + conditioning_states, control_hidden_states = rematted_layer_forward(hidden_states, control_hidden_states) + control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) control_hidden_states_list = control_hidden_states_list[::-1] @@ -576,13 +536,9 @@ def layer_forward_vace(hidden_states): hidden_states = hidden_states + control_hint * scale # 6. Output norm, projection & unpatchify - shift, scale = jnp.split( - self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1 - ) + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) - hidden_states = ( - self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift - ).astype(hidden_states.dtype) + hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) with jax.named_scope("proj_out"): hidden_states = self.proj_out(hidden_states) # Linear layer. diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py old mode 100644 new mode 100755 index 7a4b8841..2e76bab9 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import os @@ -72,7 +72,7 @@ def rename_for_custom_trasformer(key): return renamed_pt_key -def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers): +def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=40): if scan_layers: if "blocks" in pt_tuple_key: new_key = ("blocks",) + pt_tuple_key[2:] @@ -89,7 +89,7 @@ def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_d if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] else: - new_tensor = jnp.zeros((40,) + flax_tensor.shape) + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) flax_tensor = new_tensor.at[block_index].set(flax_tensor) return flax_key, flax_tensor @@ -186,7 +186,6 @@ def load_wan_transformer( scan_layers: bool = True, subfolder: str = "", ): - if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH: @@ -260,23 +259,23 @@ def load_base_wan_transformer( renamed_pt_key = rename_key(pt_key) if "condition_embedder" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") - renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") if "image_embedder" in renamed_pt_key: - if "net.0.proj" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") - elif "net_0.proj" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") - if "net.2" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") - renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") - if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("weight", "scale") - renamed_pt_key = renamed_pt_key.replace("kernel", "scale") + if "net.0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") + elif "net_0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") + if "net.2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") + renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") + if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("weight", "scale") + renamed_pt_key = renamed_pt_key.replace("kernel", "scale") renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") @@ -285,7 +284,7 @@ def load_base_wan_transformer( renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers) + flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) diff --git a/src/maxdiffusion/multihost_dataloading.py b/src/maxdiffusion/multihost_dataloading.py index 4be0ba8d..273ded82 100644 --- a/src/maxdiffusion/multihost_dataloading.py +++ b/src/maxdiffusion/multihost_dataloading.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=unused-import """SPMD Multihost Dataloading Utilities. diff --git a/src/maxdiffusion/pedagogical_examples/attention_comparison.py b/src/maxdiffusion/pedagogical_examples/attention_comparison.py index 024ef92a..6981e092 100644 --- a/src/maxdiffusion/pedagogical_examples/attention_comparison.py +++ b/src/maxdiffusion/pedagogical_examples/attention_comparison.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os import time diff --git a/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py b/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py index 28230251..47a521fb 100644 --- a/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py +++ b/src/maxdiffusion/pedagogical_examples/checkpoint_params_restore.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """This script is used an example of how to restore params from a orbax train_state ckpt.""" diff --git a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py index 16c015a7..cc547cc8 100644 --- a/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py +++ b/src/maxdiffusion/pedagogical_examples/dataset_tf_cache_to_tfrecord.py @@ -1,18 +1,19 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os import argparse import tensorflow as tf diff --git a/src/maxdiffusion/pedagogical_examples/parameter_count.py b/src/maxdiffusion/pedagogical_examples/parameter_count.py index 8e591b4e..e9fe4542 100644 --- a/src/maxdiffusion/pedagogical_examples/parameter_count.py +++ b/src/maxdiffusion/pedagogical_examples/parameter_count.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ from typing import Sequence from absl import app import jax @@ -21,7 +22,6 @@ def run(config): - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( config.pretrained_model_name_or_path, revision=config.revision, diff --git a/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py b/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py index 08a0f46a..350f6d0c 100644 --- a/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py +++ b/src/maxdiffusion/pedagogical_examples/save_sd_checkpoint.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Load and save a checkpoint. This is useful for uploading checkpoints to gcs and later loading them from gcs directly. diff --git a/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py b/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py index 64aa0f0b..860f4beb 100644 --- a/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py +++ b/src/maxdiffusion/pedagogical_examples/save_sdxl_checkpoint.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """Load and save a checkpoint. This is useful for uploading checkpoints to gcs and later loading them from gcs directly. diff --git a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py index 6298adda..a0a38021 100644 --- a/src/maxdiffusion/pedagogical_examples/to_tfrecords.py +++ b/src/maxdiffusion/pedagogical_examples/to_tfrecords.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Example file of how to prepare tfrecords with latents and hidden_states preprocessed. @@ -54,14 +54,12 @@ dl_manager = tfds.download.DownloadManager(download_dir="/tmp") tmp_dataset = "dataset" -TRANSFORMS = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(size=512), - transforms.Normalize([0.5], [0.5]), - ] -) +TRANSFORMS = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(size=512), + transforms.Normalize([0.5], [0.5]), +]) def delete_files(path): @@ -184,7 +182,6 @@ def img_to_latents(img, p_vae_apply, sample_rng): def run(config): - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( config.pretrained_model_name_or_path, revision=config.revision, diff --git a/src/maxdiffusion/pedagogical_examples/unet_shardings.py b/src/maxdiffusion/pedagogical_examples/unet_shardings.py index bc956b1f..38ed1af9 100644 --- a/src/maxdiffusion/pedagogical_examples/unet_shardings.py +++ b/src/maxdiffusion/pedagogical_examples/unet_shardings.py @@ -1,20 +1,20 @@ #!/usr/bin/python3 """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """This script is used an example of how to shard the UNET on TPU.""" diff --git a/src/maxdiffusion/pipelines/__init__.py b/src/maxdiffusion/pipelines/__init__.py index 227784ba..019c79a8 100644 --- a/src/maxdiffusion/pipelines/__init__.py +++ b/src/maxdiffusion/pipelines/__init__.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import TYPE_CHECKING @@ -51,16 +51,14 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) else: - _import_structure["stable_diffusion"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) + _import_structure["stable_diffusion"].extend([ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ]) try: if not is_flax_available(): @@ -82,20 +80,15 @@ _import_structure["controlnet"].extend( ["FlaxStableDiffusionControlNetPipeline", "FlaxStableDiffusionXLControlNetPipeline"] ) - _import_structure["stable_diffusion"].extend( - [ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ] - ) - _import_structure["stable_diffusion_xl"].extend( - [ - "FlaxStableDiffusionXLPipeline", - ] - ) + _import_structure["stable_diffusion"].extend([ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ]) + _import_structure["stable_diffusion_xl"].extend([ + "FlaxStableDiffusionXLPipeline", + ]) if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/maxdiffusion/pipelines/controlnet/__init__.py b/src/maxdiffusion/pipelines/controlnet/__init__.py index e650f9d5..fe7070b0 100644 --- a/src/maxdiffusion/pipelines/controlnet/__init__.py +++ b/src/maxdiffusion/pipelines/controlnet/__init__.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py b/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py index 885b0b37..b8b1cc18 100644 --- a/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py +++ b/src/maxdiffusion/pipelines/controlnet/pipeline_flax_controlnet_sdxl.py @@ -112,7 +112,6 @@ def __call__( output_type: str = None, jit: bool = False, ): - if isinstance(guidance_scale, float) and jit: # Convert to a tensor so each device gets a copy. guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) diff --git a/src/maxdiffusion/pipelines/flux/__init__.py b/src/maxdiffusion/pipelines/flux/__init__.py index 5457eef5..39ea05b5 100644 --- a/src/maxdiffusion/pipelines/flux/__init__.py +++ b/src/maxdiffusion/pipelines/flux/__init__.py @@ -1,18 +1,19 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ _import_structure = {"pipeline_jflux": "JfluxPipeline"} from .flux_pipeline import ( diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py index 112338d5..15b2c4f5 100644 --- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py +++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py @@ -131,7 +131,6 @@ def prepare_latents( dtype: jnp.dtype, rng: Array, ): - # VAE applies 8x compression on images but we must also account for packing which # requires latent height and width to be divisibly by 2. height = 2 * (height // (vae_scale_factor * 2)) @@ -194,7 +193,6 @@ def get_t5_prompt_embeds( encode_in_batches=False, encode_batch_size=None, ): - prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -243,7 +241,6 @@ def encode_prompt( encode_in_batches: bool = False, encode_batch_size: int = None, ): - if encode_in_batches: assert encode_in_batches is not None @@ -271,7 +268,6 @@ def encode_prompt( def _generate( self, flux_params, vae_params, latents, latent_image_ids, prompt_embeds, txt_ids, vec, guidance_vec, c_ts, p_ts ): - def loop_body( step, args, diff --git a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py index 1b8f4deb..4aa3baf1 100644 --- a/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py @@ -901,7 +901,6 @@ def transformer_forward_pass( skip_layer_mask, skip_layer_strategy, ): - noise_pred = transformer.apply( {"params": state.params}, hidden_states=latents, diff --git a/src/maxdiffusion/pipelines/pipeline_flax_utils.py b/src/maxdiffusion/pipelines/pipeline_flax_utils.py index 8507d96e..da3a755b 100644 --- a/src/maxdiffusion/pipelines/pipeline_flax_utils.py +++ b/src/maxdiffusion/pipelines/pipeline_flax_utils.py @@ -473,7 +473,7 @@ def load_module(name, value): class_obj = import_flax_or_no_model(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in importable_classes.keys()} + class_candidates = dict.fromkeys(importable_classes.keys(), class_obj) else: # else we just import it from the library. diff --git a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py index 9ac32eb7..564b0dfa 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion/__init__.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ from typing import TYPE_CHECKING from ...utils import ( @@ -84,13 +85,11 @@ StableDiffusionPix2PixZeroPipeline, ) - _dummy_objects.update( - { - "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, - "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, - "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, - } - ) + _dummy_objects.update({ + "StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline, + "StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline, + "StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline, + }) else: _import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"] _import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"] diff --git a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py index 2eb01334..1ae1b641 100644 --- a/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py +++ b/src/maxdiffusion/pipelines/stable_diffusion_xl/__init__.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ from typing import TYPE_CHECKING from ...utils import ( diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py index 83a537f8..9a17b1e7 100644 --- a/src/maxdiffusion/pipelines/wan/__init__.py +++ b/src/maxdiffusion/pipelines/wan/__init__.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ from .wan_pipeline import WanPipeline diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 0bc93f0c..7c0314b4 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -94,9 +94,13 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder: str = "", ): - def create_model(rngs: nnx.Rngs, wan_config: dict): wan_transformer = WanModel(**wan_config, rngs=rngs) return wan_transformer @@ -111,7 +115,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): # WAN 2.2 I2V uses VAE-encoded latent conditioning (image_dim and added_kv_proj_dim are None in the transformer config) if config.model_name == "wan2.1": if wan_config.get("image_seq_len") is None: - wan_config["image_seq_len"] = 257 + wan_config["image_seq_len"] = 257 wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype @@ -201,6 +205,7 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ + def __init__( self, tokenizer: AutoTokenizer, @@ -252,21 +257,18 @@ def load_tokenizer(cls, config: HyperParameters): @classmethod def load_image_encoder(cls, config: HyperParameters): - image_processor = CLIPImageProcessor.from_pretrained( - config.pretrained_model_name_or_path, subfolder="image_processor" - ) + image_processor = CLIPImageProcessor.from_pretrained(config.pretrained_model_name_or_path, subfolder="image_processor") try: - image_encoder = FlaxCLIPVisionModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 - ) + image_encoder = FlaxCLIPVisionModel.from_pretrained( + config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 + ) except Exception as e: - max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") - raise + max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") + raise return image_processor, image_encoder @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - def create_model(rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( config.pretrained_model_name_or_path, @@ -374,7 +376,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline return model max_logging.log("Quantizing transformer with Qwix.") - batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) + batch_size = config.global_batch_size_to_train_on latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) model_inputs = (latents, timesteps, prompt_embeds) with mesh: @@ -384,10 +386,22 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): + cls, + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder="transformer", + ): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder + devices_array=devices_array, + mesh=mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder=subfolder, ) return wan_transformer @@ -401,17 +415,16 @@ def load_scheduler(cls, config): return scheduler, scheduler_state def encode_image(self, image: PipelineImageInput, num_videos_per_prompt: int = 1): - if not isinstance(image, list): - image = [image] - image_inputs = self.image_processor(images=image, return_tensors="np") - pixel_values = jnp.array(image_inputs.pixel_values) - - image_encoder_output = self.image_encoder(pixel_values, output_hidden_states=True) - image_embeds = image_encoder_output.hidden_states[-2] + if not isinstance(image, list): + image = [image] + image_inputs = self.image_processor(images=image, return_tensors="np") + pixel_values = jnp.array(image_inputs.pixel_values) - image_embeds = jnp.repeat(image_embeds, num_videos_per_prompt, axis=0) - return image_embeds + image_encoder_output = self.image_encoder(pixel_values, output_hidden_states=True) + image_embeds = image_encoder_output.hidden_states[-2] + image_embeds = jnp.repeat(image_embeds, num_videos_per_prompt, axis=0) + return image_embeds def _get_t5_prompt_embeds( self, @@ -508,82 +521,93 @@ def prepare_latents_i2v_base( dtype: jnp.dtype, last_image: Optional[jax.Array] = None, ) -> Tuple[jax.Array, jax.Array]: - """ - Encodes the initial image(s) into latents to be used as conditioning. - Returns: - latent_condition: The VAE encoded latents of the image(s). - video_condition: The input to the VAE. - """ - height, width = image.shape[-2:] - image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] - - if last_image is None: - video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 - ) - else: - last_image = last_image[:, :, jnp.newaxis, :, :] - video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], axis=2 - ) - - vae_dtype = getattr(self.vae, "dtype", jnp.float32) - video_condition = video_condition.astype(vae_dtype) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() - - # Normalize latents - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) - latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) - latent_condition = encoded_output - latent_condition = latent_condition.astype(dtype) - latent_condition = (latent_condition - latents_mean) / latents_std - - return latent_condition, video_condition + """ + Encodes the initial image(s) into latents to be used as conditioning. + Returns: + latent_condition: The VAE encoded latents of the image(s). + video_condition: The input to the VAE. + """ + height, width = image.shape[-2:] + image = image[:, :, jnp.newaxis, :, :] # [B, C, 1, H, W] + + if last_image is None: + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 + ) + else: + last_image = last_image[:, :, jnp.newaxis, :, :] + video_condition = jnp.concatenate( + [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], + axis=2, + ) + + vae_dtype = getattr(self.vae, "dtype", jnp.float32) + video_condition = video_condition.astype(vae_dtype) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + data_mesh_size = self.mesh.shape[self.config.mesh_axes[0]] + if video_condition.shape[0] % data_mesh_size == 0: + sharding_spec = P(self.config.mesh_axes[0], None, None, None, None) + video_condition = jax.lax.with_sharding_constraint(video_condition, sharding_spec) + encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode() + + # Normalize latents + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim) + latents_std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim) + latent_condition = encoded_output + latent_condition = latent_condition.astype(dtype) + latent_condition = (latent_condition - latents_mean) / latents_std + + return latent_condition, video_condition def _denormalize_latents(self, latents: jax.Array) -> jax.Array: - """Denormalizes latents using VAE statistics.""" - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) - return latents + """Denormalizes latents using VAE statistics.""" + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(jnp.float32) + return latents def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: - """Decodes latents to video frames and postprocesses.""" - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] + """Decodes latents to video frames and postprocesses.""" + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - return self.video_processor.postprocess_video(video, output_type="np") + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + return self.video_processor.postprocess_video(video, output_type="np") @classmethod def _create_common_components(cls, config, vae_only=False, i2v=False): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - components = { - "vae": wan_vae, "vae_cache": vae_cache, - "devices_array": devices_array, "rngs": rngs, "mesh": mesh, - "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None, - "image_processor": None, "image_encoder": None - } - - if not vae_only: - components["tokenizer"] = cls.load_tokenizer(config=config) - components["text_encoder"] = cls.load_text_encoder(config=config) - components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) - if i2v and config.model_name == 'wan2.1': - components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) - return components + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + components = { + "vae": wan_vae, + "vae_cache": vae_cache, + "devices_array": devices_array, + "rngs": rngs, + "mesh": mesh, + "tokenizer": None, + "text_encoder": None, + "scheduler": None, + "scheduler_state": None, + "image_processor": None, + "image_encoder": None, + } + + if not vae_only: + components["tokenizer"] = cls.load_tokenizer(config=config) + components["text_encoder"] = cls.load_text_encoder(config=config) + components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + if i2v and config.model_name == "wan2.1": + components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) + return components @abstractmethod def _get_num_channel_latents(self) -> int: @@ -603,7 +627,7 @@ def _prepare_model_inputs_i2v( last_image: Optional[PIL.Image.Image] = None, ): if prompt is not None and isinstance(prompt, str): - prompt = [prompt] + prompt = [prompt] batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt effective_batch_size = batch_size * num_videos_per_prompt @@ -617,30 +641,29 @@ def _prepare_model_inputs_i2v( negative_prompt_embeds=negative_prompt_embeds, ) - # 2. Encode Image (only for WAN 2.1 I2V which uses CLIP image embeddings) # WAN 2.2 I2V does not use CLIP image embeddings, it uses VAE latent conditioning instead transformer_dtype = self.config.activations_dtype if self.config.model_name == "wan2.1": - # WAN 2.1 I2V: Use CLIP image encoder - if image_embeds is None: - images_to_encode = [image] - if last_image is None: - images_to_encode = [image] - else: - images_to_encode = [image, last_image] - image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt) - self.image_seq_len = image_embeds.shape[1] - - if batch_size > 1: - image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1)) - - image_embeds = image_embeds.astype(transformer_dtype) + # WAN 2.1 I2V: Use CLIP image encoder + if image_embeds is None: + images_to_encode = [image] + if last_image is None: + images_to_encode = [image] + else: + images_to_encode = [image, last_image] + image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt) + self.image_seq_len = image_embeds.shape[1] + + if batch_size > 1: + image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1)) + + image_embeds = image_embeds.astype(transformer_dtype) else: - # WAN 2.2 I2V: No CLIP image embeddings, set to None or empty tensor - # The actual image conditioning happens via VAE latents in prepare_latents - image_embeds = None + # WAN 2.2 I2V: No CLIP image embeddings, set to None or empty tensor + # The actual image conditioning happens via VAE latents in prepare_latents + image_embeds = None prompt_embeds = prompt_embeds.astype(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype) @@ -648,7 +671,7 @@ def _prepare_model_inputs_i2v( # Use same sharding logic as T2V pipeline for consistent behavior data_sharding = NamedSharding(self.mesh, P()) if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) @@ -656,22 +679,21 @@ def _prepare_model_inputs_i2v( return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size - def _prepare_model_inputs( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( @@ -724,8 +746,9 @@ def _prepare_model_inputs( @abstractmethod def __call__(self, **kwargs): - """Runs the inference pipeline.""" - pass + """Runs the inference pipeline.""" + pass + @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( @@ -740,7 +763,12 @@ def transformer_forward_pass( encoder_hidden_states_image=None, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image) + noise_pred = wan_transformer( + hidden_states=latents, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=encoder_hidden_states_image, + ) if do_classifier_free_guidance: bsz = latents.shape[0] // 2 noise_cond = noise_pred[:bsz] # First half = conditional diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 5617e3b7..62c1a34a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -17,14 +17,17 @@ from typing import List, Union, Optional from ...pyconfig import HyperParameters from functools import partial +from contextlib import nullcontext from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipeline2_1(WanPipeline): """Pipeline for WAN 2.1 with a single transformer.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): super().__init__(config=config, **kwargs) self.transformer = transformer @@ -41,27 +44,27 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" + subfolder="transformer", ) pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) return pipeline, transformer @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) return pipeline @@ -74,20 +77,20 @@ def _get_num_channel_latents(self) -> int: return self.transformer.config.in_channels def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - vae_only: bool = False, + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + vae_only: bool = False, ): latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, @@ -113,8 +116,15 @@ def __call__( scheduler=self.scheduler, scheduler_state=scheduler_state, ) + # Set the TE shard_guard context_manager if using TE cudnn_flash attention + if self.config.attention == "cudnn_flash_te": + from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error + + shard_guard = global_shard_guard(MeshResource(cp_resource="context")) + else: + shard_guard = nullcontext() - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard: latents = p_run_inference( graphdef=graphdef, sharded_state=state, @@ -126,6 +136,7 @@ def __call__( latents = self._denormalize_latents(latents) return self._decode_latents_to_video(latents) + def run_inference_2_1( graphdef, sharded_state, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index c0400f60..82261eda 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -17,42 +17,52 @@ from typing import List, Union, Optional from ...pyconfig import HyperParameters from functools import partial +from contextlib import nullcontext from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipeline2_2(WanPipeline): """Pipeline for WAN 2.2 with dual transformers.""" - def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + + def __init__( + self, + config: HyperParameters, + low_noise_transformer: Optional[WanModel], + high_noise_transformer: Optional[WanModel], + **kwargs + ): super().__init__(config=config, **kwargs) self.low_noise_transformer = low_noise_transformer self.high_noise_transformer = high_noise_transformer + self.boundary_ratio = config.boundary_ratio @classmethod def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): common_components = cls._create_common_components(config, vae_only) low_noise_transformer, high_noise_transformer = None, None if not vae_only and load_transformer: - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer_2" - ) - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer" - ) - - pipeline = cls( + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + + pipeline = cls( tokenizer=common_components["tokenizer"], text_encoder=common_components["text_encoder"], low_noise_transformer=low_noise_transformer, @@ -64,7 +74,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t devices_array=common_components["devices_array"], mesh=common_components["mesh"], config=config, - ) + ) return pipeline, low_noise_transformer, high_noise_transformer @classmethod @@ -76,29 +86,30 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform @classmethod def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init( + config, restored_checkpoint, vae_only, load_transformer + ) return pipeline def _get_num_channel_latents(self) -> int: return self.low_noise_transformer.config.in_channels def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, ): latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, @@ -118,17 +129,26 @@ def __call__( low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + p_run_inference = partial( run_inference_2_2, guidance_scale_low=guidance_scale_low, guidance_scale_high=guidance_scale_high, - boundary=boundary, + boundary=boundary_timestep, num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, ) + # Set the TE shard_guard context_manager if using TE cudnn_flash attention + if self.config.attention == "cudnn_flash_te": + from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error + + shard_guard = global_shard_guard(MeshResource(cp_resource="context")) + else: + shard_guard = nullcontext() - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard: latents = p_run_inference( low_noise_graphdef=low_noise_graphdef, low_noise_state=low_noise_state, @@ -143,6 +163,7 @@ def __call__( latents = self._denormalize_latents(latents) return self._decode_latents_to_video(latents) + def run_inference_2_2( low_noise_graphdef, low_noise_state, @@ -167,17 +188,27 @@ def run_inference_2_2( def low_noise_branch(operands): latents, timestep, prompt_embeds = operands return transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_low + low_noise_graphdef, + low_noise_state, + low_noise_rest, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance, + guidance_scale_low, ) def high_noise_branch(operands): latents, timestep, prompt_embeds = operands return transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_high + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents, + timestep, + prompt_embeds, + do_classifier_free_guidance, + guidance_scale_high, ) for step in range(num_inference_steps): @@ -192,10 +223,7 @@ def high_noise_branch(operands): # - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise). # - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise). noise_pred, latents = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latents, timestep, prompt_embeds) + use_high_noise, high_noise_branch, low_noise_branch, (latents, timestep, prompt_embeds) ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0380a07c..0622ec79 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -26,8 +26,10 @@ from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipelineI2V_2_1(WanPipeline): """Pipeline for WAN 2.1 Image-to-Video.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): super().__init__(config=config, **kwargs) self.transformer = transformer @@ -44,28 +46,28 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" + subfolder="transformer", ) pipeline = cls( - tokenizer=common_components["tokenizer"], - text_encoder=common_components["text_encoder"], - image_processor=common_components["image_processor"], - image_encoder=common_components["image_encoder"], - transformer=transformer, - vae=common_components["vae"], - vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], - scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + image_processor=common_components["image_processor"], + image_encoder=common_components["image_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) return pipeline, transformer @classmethod def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer) pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) return pipeline @@ -87,110 +89,109 @@ def prepare_latents( last_image: Optional[jax.Array] = None, num_videos_per_prompt: int = 1, ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: - - if hasattr(image, "detach"): - image = image.detach().cpu().numpy() - image = jnp.array(image) - - if last_image is not None: - if hasattr(last_image, "detach"): - last_image = last_image.detach().cpu().numpy() - last_image = jnp.array(last_image) - - if num_videos_per_prompt > 1: - image = jnp.repeat(image, num_videos_per_prompt, axis=0) - if last_image is not None: - last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0) - - num_channels_latents = self.vae.z_dim - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - latent_height = height // self.vae_scale_factor_spatial - latent_width = width // self.vae_scale_factor_spatial - - shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) - - if latents is None: - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) - else: - latents = latents.astype(dtype) - latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) - mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) - if last_image is None: - mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) - else: - mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) - mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) - mask_lat_size = mask_lat_size.reshape( - batch_size, - 1, - num_latent_frames, - self.vae_scale_factor_temporal, - latent_height, - latent_width - ) - mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) - condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) - return latents, condition, None - + if hasattr(image, "detach"): + image = image.detach().cpu().numpy() + image = jnp.array(image) + + if last_image is not None: + if hasattr(last_image, "detach"): + last_image = last_image.detach().cpu().numpy() + last_image = jnp.array(last_image) + + if num_videos_per_prompt > 1: + image = jnp.repeat(image, num_videos_per_prompt, axis=0) + if last_image is not None: + last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0) + + num_channels_latents = self.vae.z_dim + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + + if latents is None: + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + else: + latents = latents.astype(dtype) + latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) + mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) + if last_image is None: + mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) + else: + mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) + mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) + mask_lat_size = mask_lat_size.reshape( + batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) + condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) + return latents, condition, None def __call__( - self, - prompt: Union[str, List[str]], - image: PipelineImageInput, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - image_embeds: Optional[jax.Array] = None, - last_image: Optional[PipelineImageInput] = None, - output_type: Optional[str] = "np", - rng: Optional[jax.Array] = None, + self, + prompt: Union[str, List[str]], + image: PipelineImageInput, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "np", + rng: Optional[jax.Array] = None, ): - height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames # Validate and adjust num_frames to ensure proper reshaping in prepare_latents if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " - f"Rounding {num_frames} to the nearest valid number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - max_logging.log(f"Adjusted num_frames to: {num_frames}") + max_logging.log( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + f"Rounding {num_frames} to the nearest valid number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( - prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length, - prompt_embeds, negative_prompt_embeds, image_embeds, last_image + prompt, + image, + negative_prompt, + num_videos_per_prompt, + max_sequence_length, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + last_image, ) def _process_image_input(img_input, height, width, num_videos_per_prompt): - if img_input is None: - return None - tensor = self.video_processor.preprocess(img_input, height=height, width=width) - jax_array = jnp.array(tensor.cpu().numpy()) - if jax_array.ndim == 3: - jax_array = jax_array[None, ...] # Add batch dimension - if num_videos_per_prompt > 1: - jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) - return jax_array + if img_input is None: + return None + tensor = self.video_processor.preprocess(img_input, height=height, width=width) + jax_array = jnp.array(tensor.cpu().numpy()) + if jax_array.ndim == 3: + jax_array = jax_array[None, ...] # Add batch dimension + if num_videos_per_prompt > 1: + jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) + return jax_array image_tensor = _process_image_input(image, height, width, effective_batch_size) last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) if rng is None: - rng = jax.random.key(self.config.seed) + rng = jax.random.key(self.config.seed) latents_rng, inference_rng = jax.random.split(rng) latents, condition, first_frame_mask = self.prepare_latents( @@ -213,7 +214,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) data_sharding = NamedSharding(self.mesh, P()) if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) latents = jax.device_put(latents, data_sharding) condition = jax.device_put(condition, data_sharding) @@ -221,7 +222,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) image_embeds = jax.device_put(image_embeds, data_sharding) if first_frame_mask is not None: - first_frame_mask = jax.device_put(first_frame_mask, data_sharding) + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) p_run_inference = partial( run_inference_2_1_i2v, @@ -233,7 +234,6 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): scheduler=self.scheduler, ) - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( latents=latents, @@ -252,7 +252,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): def run_inference_2_1_i2v( - graphdef, sharded_state, rest_of_state, + graphdef, + sharded_state, + rest_of_state, latents: jnp.array, condition: jnp.array, prompt_embeds: jnp.array, @@ -273,14 +275,18 @@ def run_inference_2_1_i2v( t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] latents_input = latents if do_classifier_free_guidance: - latents_input = jnp.concatenate([latents, latents], axis=0) + latents_input = jnp.concatenate([latents, latents], axis=0) latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) noise_pred, _ = transformer_forward_pass( - graphdef, sharded_state, rest_of_state, - latent_model_input, timestep, prompt_embeds, + graphdef, + sharded_state, + rest_of_state, + latent_model_input, + timestep, + prompt_embeds, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, encoder_hidden_states_image=image_embeds, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index ab24a651..1f65f452 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -26,9 +26,17 @@ from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler + class WanPipelineI2V_2_2(WanPipeline): """Pipeline for WAN 2.2 Image-to-Video.""" - def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + + def __init__( + self, + config: HyperParameters, + low_noise_transformer: Optional[WanModel], + high_noise_transformer: Optional[WanModel], + **kwargs, + ): super().__init__(config=config, **kwargs) self.low_noise_transformer = low_noise_transformer self.high_noise_transformer = high_noise_transformer @@ -39,26 +47,38 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t common_components = cls._create_common_components(config, vae_only, i2v=True) low_noise_transformer, high_noise_transformer = None, None if not vae_only: - if load_transformer: - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], mesh=common_components["mesh"], - rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer" - ) - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], mesh=common_components["mesh"], - rngs=common_components["rngs"], config=config, restored_checkpoint=restored_checkpoint, - subfolder="transformer_2" - ) + if load_transformer: + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) pipeline = cls( - tokenizer=common_components["tokenizer"], text_encoder=common_components["text_encoder"], - image_processor=common_components["image_processor"], image_encoder=common_components["image_encoder"], - low_noise_transformer=low_noise_transformer, high_noise_transformer=high_noise_transformer, - vae=common_components["vae"], vae_cache=common_components["vae_cache"], - scheduler=common_components["scheduler"], scheduler_state=common_components["scheduler_state"], - devices_array=common_components["devices_array"], mesh=common_components["mesh"], - config=config, + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + image_processor=common_components["image_processor"], + image_encoder=common_components["image_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) return pipeline, low_noise_transformer, high_noise_transformer @@ -75,27 +95,26 @@ def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_ return pipeline def prepare_latents( - self, - image: jax.Array, - batch_size: int, - height: int, - width: int, - num_frames: int, - dtype: jnp.dtype, - rng: jax.Array, - latents: Optional[jax.Array] = None, - last_image: Optional[jax.Array] = None, - num_videos_per_prompt: int = 1, -) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: - + self, + image: jax.Array, + batch_size: int, + height: int, + width: int, + num_frames: int, + dtype: jnp.dtype, + rng: jax.Array, + latents: Optional[jax.Array] = None, + last_image: Optional[jax.Array] = None, + num_videos_per_prompt: int = 1, + ) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]: if hasattr(image, "detach"): - image = image.detach().cpu().numpy() + image = image.detach().cpu().numpy() image = jnp.array(image) if last_image is not None: - if hasattr(last_image, "detach"): - last_image = last_image.detach().cpu().numpy() - last_image = jnp.array(last_image) + if hasattr(last_image, "detach"): + last_image = last_image.detach().cpu().numpy() + last_image = jnp.array(last_image) num_channels_latents = self.vae.z_dim num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -105,16 +124,16 @@ def prepare_latents( shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) if latents is None: - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) + latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) else: - latents = latents.astype(dtype) + latents = latents.astype(dtype) latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image) mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype) if last_image is None: - mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) + mask_lat_size = mask_lat_size.at[:, :, 1:, :, :].set(0) else: - mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) + mask_lat_size = mask_lat_size.at[:, :, 1:-1, :, :].set(0) first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) @@ -127,59 +146,67 @@ def prepare_latents( return latents, condition, None def __call__( - self, - prompt: Union[str, List[str]], - image: PipelineImageInput, - negative_prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, - latents: Optional[jax.Array] = None, - prompt_embeds: Optional[jax.Array] = None, - negative_prompt_embeds: Optional[jax.Array] = None, - image_embeds: Optional[jax.Array] = None, - last_image: Optional[PipelineImageInput] = None, - output_type: Optional[str] = "np", - rng: Optional[jax.Array] = None, + self, + prompt: Union[str, List[str]], + image: PipelineImageInput, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + image_embeds: Optional[jax.Array] = None, + last_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "np", + rng: Optional[jax.Array] = None, ): height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " - f"Rounding {num_frames} to the nearest valid number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - max_logging.log(f"Adjusted num_frames to: {num_frames}") + max_logging.log( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + f"Rounding {num_frames} to the nearest valid number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v( - prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length, - prompt_embeds, negative_prompt_embeds, image_embeds, last_image + prompt, + image, + negative_prompt, + num_videos_per_prompt, + max_sequence_length, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + last_image, ) + def _process_image_input(img_input, height, width, num_videos_per_prompt): - if img_input is None: - return None - tensor = self.video_processor.preprocess(img_input, height=height, width=width) - jax_array = jnp.array(tensor.cpu().numpy()) - if jax_array.ndim == 3: - jax_array = jax_array[None, ...] # Add batch dimension - if num_videos_per_prompt > 1: - jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) - return jax_array + if img_input is None: + return None + tensor = self.video_processor.preprocess(img_input, height=height, width=width) + jax_array = jnp.array(tensor.cpu().numpy()) + if jax_array.ndim == 3: + jax_array = jax_array[None, ...] # Add batch dimension + if num_videos_per_prompt > 1: + jax_array = jnp.repeat(jax_array, num_videos_per_prompt, axis=0) + return jax_array image_tensor = _process_image_input(image, height, width, effective_batch_size) last_image_tensor = _process_image_input(last_image, height, width, effective_batch_size) if rng is None: - rng = jax.random.key(self.config.seed) + rng = jax.random.key(self.config.seed) latents_rng, inference_rng = jax.random.split(rng) # For WAN 2.2, image_embeds may be None (no CLIP image encoder) @@ -206,17 +233,16 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) data_sharding = NamedSharding(self.mesh, P()) if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) + data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) latents = jax.device_put(latents, data_sharding) condition = jax.device_put(condition, data_sharding) prompt_embeds = jax.device_put(prompt_embeds, data_sharding) negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) # WAN 2.2 I2V doesn't use image_embeds (it's None), but we still need to pass it to the function if image_embeds is not None: - image_embeds = jax.device_put(image_embeds, data_sharding) + image_embeds = jax.device_put(image_embeds, data_sharding) if first_frame_mask is not None: - first_frame_mask = jax.device_put(first_frame_mask, data_sharding) - + first_frame_mask = jax.device_put(first_frame_mask, data_sharding) boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps @@ -232,10 +258,16 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, low_noise_state=low_noise_state, low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, high_noise_state=high_noise_state, high_noise_rest=high_noise_rest, - latents=latents, condition=condition, - prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + condition=condition, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, scheduler_state=scheduler_state, ) latents = jnp.transpose(latents, (0, 4, 1, 2, 3)) @@ -245,9 +277,14 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): return latents return self._decode_latents_to_video(latents) + def run_inference_2_2_i2v( - low_noise_graphdef, low_noise_state, low_noise_rest, - high_noise_graphdef, high_noise_state, high_noise_rest, + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, latents: jnp.array, condition: jnp.array, prompt_embeds: jnp.array, @@ -260,51 +297,59 @@ def run_inference_2_2_i2v( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, ): - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - def high_noise_branch(operands): - latents_input, ts_input, pe_input, ie_input = operands - latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) - noise_pred, latents_out = transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents_input, ts_input, pe_input, - do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_high, - encoder_hidden_states_image=ie_input - ) - return noise_pred, latents_out - - def low_noise_branch(operands): - latents_input, ts_input, pe_input, ie_input = operands - latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) - noise_pred, latents_out = transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents_input, ts_input, pe_input, - do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale_low, - encoder_hidden_states_image=ie_input - ) - return noise_pred, latents_out + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + + def high_noise_branch(operands): + latents_input, ts_input, pe_input, ie_input = operands + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + noise_pred, latents_out = transformer_forward_pass( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents_input, + ts_input, + pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale_high, + encoder_hidden_states_image=ie_input, + ) + return noise_pred, latents_out + + def low_noise_branch(operands): + latents_input, ts_input, pe_input, ie_input = operands + latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) + noise_pred, latents_out = transformer_forward_pass( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + latents_input, + ts_input, + pe_input, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale_low, + encoder_hidden_states_image=ie_input, + ) + return noise_pred, latents_out + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder + if image_embeds is not None: + image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) + condition = jnp.concatenate([condition] * 2) + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + latents_input = latents if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder - if image_embeds is not None: - image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) - condition = jnp.concatenate([condition] * 2) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - latents_input = latents - if do_classifier_free_guidance: - latents_input = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - use_high_noise = jnp.greater_equal(t, boundary) - noise_pred, _ = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latent_model_input, timestep, prompt_embeds, image_embeds) - ) - noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents + latents_input = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + use_high_noise = jnp.greater_equal(t, boundary) + noise_pred, _ = jax.lax.cond( + use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds) + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py similarity index 84% rename from src/maxdiffusion/pipelines/wan/wan_vace_pipeline.py rename to src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py index 60c7f33f..487cc85e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_vace_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py @@ -31,14 +31,13 @@ from ...models.wan.transformers.transformer_wan_vace import WanVACEModel from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler from ...models.modeling_flax_pytorch_utils import torch2jax -from .wan_pipeline import WanPipeline, cast_with_exclusion +from .wan_pipeline import cast_with_exclusion +from .wan_pipeline_2_1 import WanPipeline2_1 import torch import PIL -def retrieve_latents( - encoder_output: torch.Tensor, rngs=None, sample_mode: str = "sample" -): +def retrieve_latents(encoder_output: torch.Tensor, rngs=None, sample_mode: str = "sample"): """Extracts the latent codes from the encoder object. From https://github.com/huggingface/diffusers/blob/8d415a6f481ff1b26168c046267628419650f930/src/diffusers/pipelines/wan/pipeline_wan_vace.py#L128C1-L128C4 @@ -55,9 +54,13 @@ def retrieve_latents( # For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder: str = "", ): - def create_model(rngs: nnx.Rngs, wan_config: dict): wan_vace_transformer = WanVACEModel(**wan_config, rngs=rngs) return wan_vace_transformer @@ -125,12 +128,12 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): return wan_transformer -class VaceWanPipeline(WanPipeline): +class VaceWanPipeline2_1(WanPipeline2_1): r"""Pipeline for video generation using Wan + VACE. Currently it only supports reference image(s) + text to video generation. - It extends `WanPipeline` to support additional conditioning signals. + It extends `WanPipeline2_1` to support additional conditioning signals. tokenizer ([`T5Tokenizer`]): Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), @@ -147,35 +150,31 @@ class VaceWanPipeline(WanPipeline): """ def preprocess_conditions( - self, - video: Optional[PipelineImageInput] = None, - mask: Optional[PipelineImageInput] = None, - reference_images: Optional[PipelineImageInput] = None, - batch_size: int = 1, - height: int = 480, - width: int = 832, - num_frames: int = 81, - dtype = None, -): + self, + video: Optional[PipelineImageInput] = None, + mask: Optional[PipelineImageInput] = None, + reference_images: Optional[PipelineImageInput] = None, + batch_size: int = 1, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype=None, + ): """Prepares the conditional data for inference. Based on https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/pipelines/wan/pipeline_wan_vace.py#L414 """ if video is not None: - base = self.vae_scale_factor_spatial * ( - self.transformer.config.patch_size[1] - if self.transformer is not None - else self.transformer_2.config.patch_size[1] - ) + base = self.vae_scale_factor_spatial * (self.transformer.config.patch_size[1]) video_height, video_width = self.video_processor.get_default_height_width(video[0]) if video_height * video_width > height * width: - scale = min(width / video_width, height / video_width) - video_height, video_width = int(video_height * scale), int(video_width * scale) + scale = min(width / video_width, height / video_width) + video_height, video_width = int(video_height * scale), int(video_width * scale) if video_height % base != 0 or video_width % base != 0: - video_height = (video_height // base) * base - video_width = (video_width // base) * base + video_height = (video_height // base) * base + video_width = (video_width // base) * base assert video_height * video_width <= height * width @@ -183,9 +182,7 @@ def preprocess_conditions( video = jnp.array(np.asarray(video), dtype=dtype) image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling) else: - video = jnp.zeros( - (batch_size, 3, num_frames, height, width), dtype=dtype - ) + video = jnp.zeros((batch_size, 3, num_frames, height, width), dtype=dtype) image_size = (height, width) # Use the height/width provider by user if mask is not None: @@ -202,9 +199,7 @@ def preprocess_conditions( # per video if reference_images is None or isinstance(reference_images, PIL.Image.Image): reference_images = [[reference_images] for _ in range(video.shape[0])] - elif isinstance(reference_images, (list, tuple)) and isinstance( - next(iter(reference_images)), PIL.Image.Image - ): + elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image): reference_images = [reference_images] elif ( isinstance(reference_images, (list, tuple)) @@ -243,14 +238,16 @@ def preprocess_conditions( # TODO: should we use jax/TF-based resizing here? resized_image = torch.nn.functional.interpolate( image, size=(new_height, new_width), mode="bilinear", align_corners=False - ).squeeze(0) # [C, H, W] + ).squeeze( + 0 + ) # [C, H, W] top = (image_size[0] - new_height) // 2 left = (image_size[1] - new_width) // 2 canvas = torch.ones(3, *image_size, dtype=torch.float32) canvas[:, top : top + new_height, left : left + new_width] = resized_image - canvas = canvas.permute(1, 2, 0) # Bring back to Jax + canvas = canvas.permute(1, 2, 0) # Bring back to Jax canvas = torch2jax(canvas) preprocessed_images.append(canvas) @@ -276,15 +273,9 @@ def prepare_masks( ) if mask.shape[0] != 1: - raise ValueError( - "Generating with more than one video is not yet supported. This may be supported in the future." - ) + raise ValueError("Generating with more than one video is not yet supported. This may be supported in the future.") - transformer_patch_size = ( - self.transformer.config.patch_size[1] - if self.transformer is not None - else self.transformer_2.config.patch_size[1] - ) + transformer_patch_size = self.transformer.config.patch_size[1] mask_list = [] for mask_, reference_images_batch in zip(mask, reference_images): @@ -293,14 +284,12 @@ def prepare_masks( new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size mask_ = mask_[0, :, :, :] - mask_ = mask_.view( - num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial - ) + mask_ = mask_.view(num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial) # TODO: should we refactor to use Jax/TF? mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width] mask_ = torch.nn.functional.interpolate( - mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" - ).squeeze(0) + mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" + ).squeeze(0) num_ref_images = len(reference_images_batch) if num_ref_images > 0: mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :]) @@ -312,10 +301,22 @@ def prepare_masks( @classmethod def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): + cls, + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder="transformer", + ): with mesh: wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder + devices_array=devices_array, + mesh=mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder=subfolder, ) return wan_transformer @@ -333,7 +334,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform if not vae_only: if load_transformer: with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") + transformer = cls.load_transformer( + devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer" + ) text_encoder = cls.load_text_encoder(config=config) tokenizer = cls.load_tokenizer(config=config) @@ -374,12 +377,8 @@ def check_inputs( ): if self.transformer is not None: base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] - elif self.transformer_2 is not None: - base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1] else: - raise ValueError( - "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline" - ) + raise ValueError("`transformer` component must be set in order to run inference with this pipeline") if height % base != 0 or width % base != 0: raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.") @@ -389,52 +388,50 @@ def check_inputs( if prompt is not None and prompt_embeds is not None: raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" - " only forward one of the two." + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif negative_prompt is not None and ( - not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) - ): + elif negative_prompt is not None and (not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") if video is not None: if mask is not None: if len(video) != len(mask): raise ValueError( - f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" - " they have the same length." + f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" + " they have the same length." ) if reference_images is not None: is_pil_image = isinstance(reference_images, PIL.Image.Image) is_list_of_pil_images = isinstance(reference_images, list) and all( - isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images + isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images ) is_list_of_list_of_pil_images = isinstance(reference_images, list) and all( - isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img) - for ref_img in reference_images + isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img) + for ref_img in reference_images ) if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images): raise ValueError( - "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " - "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" ) if is_list_of_list_of_pil_images and len(reference_images) != 1: raise ValueError( - "The pipeline only supports generating one video at a time at the moment. When passing a list " - "of list of reference images, where the outer list corresponds to the batch size and the inner " - "list corresponds to list of conditioning images per video, please make sure to only pass " - "one inner list of reference images (i.e., `[[, , ...]]`" + "The pipeline only supports generating one video at a time at the moment. When passing a list " + "of list of reference images, where the outer list corresponds to the batch size and the inner " + "list corresponds to list of conditioning images per video, please make sure to only pass " + "one inner list of reference images (i.e., `[[, , ...]]`" ) elif mask is not None: raise ValueError("`mask` can only be passed if `video` is passed as well.") @@ -445,7 +442,6 @@ def __call__( mask: Optional[List[PipelineImageInput]] = None, reference_images: Optional[List[PipelineImageInput]] = None, conditioning_scale: Union[float, List[float], torch.Tensor] = 1.0, - prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, height: int = 480, @@ -491,7 +487,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, video=video, mask=mask, - reference_images=reference_images + reference_images=reference_images, ) if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: @@ -507,9 +503,7 @@ def __call__( batch_size = len(prompt) if num_videos_per_prompt != 1: - raise ValueError( - "Generating multiple videos per prompt is not yet supported. This may be supported in the future." - ) + raise ValueError("Generating multiple videos per prompt is not yet supported. This may be supported in the future.") prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, @@ -520,20 +514,15 @@ def __call__( ) transformer_dtype = self.transformer.proj_out.bias.dtype - - vace_layers = ( - self.transformer.config.vace_layers - if self.transformer is not None - else self.transformer_2.config.vace_layers - ) + vace_layers = self.transformer.config.vace_layers if isinstance(conditioning_scale, (int, float)): conditioning_scale = [conditioning_scale] * len(vace_layers) if isinstance(conditioning_scale, list): if len(conditioning_scale) != len(vace_layers): raise ValueError( - f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}." - ) + f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}." + ) conditioning_scale = jnp.array(conditioning_scale) if isinstance(conditioning_scale, jax.Array): if conditioning_scale.shape[0] != len(vace_layers): @@ -557,7 +546,9 @@ def __call__( if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding)) - conditioning_latents = self.prepare_video_latents(data_sharding=data_sharding, video=video, mask=mask, reference_images=reference_images, rngs=None) + conditioning_latents = self.prepare_video_latents( + data_sharding=data_sharding, video=video, mask=mask, reference_images=reference_images, rngs=None + ) mask = self.prepare_masks(mask, reference_images) conditioning_latents = conditioning_latents.transpose(0, 4, 1, 2, 3) @@ -640,7 +631,6 @@ def prepare_video_latents( reference_images: Optional[List[List[torch.Tensor]]] = None, rngs=None, ) -> jax.Array: - if reference_images is None: # For each batch of video, we set no re # ference image (as one or more can be passed by user) @@ -652,9 +642,7 @@ def prepare_video_latents( ) if video.shape[0] != 1: - raise ValueError( - "Generating with more than one video is not yet supported. This may be supported in the future." - ) + raise ValueError("Generating with more than one video is not yet supported. This may be supported in the future.") vae_dtype = self.vae.decoder.conv_in.conv.bias.dtype video = video.astype(dtype=vae_dtype) @@ -683,7 +671,9 @@ def prepare_video_latents( reference_image = jax.device_put(reference_image, data_sharding) reference_image = reference_image[None, None, :, :, :] # [1, 1, H, W, C] - reference_latent = retrieve_latents(self.vae.encode(reference_image, feat_cache=self.vae_cache), rngs=None, sample_mode="argmax") + reference_latent = retrieve_latents( + self.vae.encode(reference_image, feat_cache=self.vae_cache), rngs=None, sample_mode="argmax" + ) reference_latent = ((reference_latent.astype(jnp.float32) - latents_mean) * latents_std).astype(vae_dtype) @@ -739,7 +729,6 @@ def run_inference( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, - control_hidden_states, control_hidden_states_scale, ): @@ -763,7 +752,6 @@ def run_inference( prompt_embeds, control_hidden_states=control_hidden_states, control_hidden_states_scale=control_hidden_states_scale, - do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, ) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 27c9f645..61f17932 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" # pylint: disable=missing-module-docstring import os @@ -32,6 +32,7 @@ _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2} _ALLOWED_TRAINING_MODEL_NAMES = {WAN2_1} + def _validate_model_name(model_name: str | None): """Raise if model_name is not in the allowed list.""" if model_name is None: @@ -39,12 +40,16 @@ def _validate_model_name(model_name: str | None): if model_name not in _ALLOWED_MODEL_NAMES: raise ValueError(f"Invalid config.model_name '{model_name}'. Allowed values: {sorted(_ALLOWED_MODEL_NAMES)}") + def _validate_training_model_name(model_name: str | None): """Raise if model_name is not in the allowed training list.""" if model_name is None: return if model_name not in _ALLOWED_TRAINING_MODEL_NAMES: - raise ValueError(f"Invalid config.model_name '{model_name}' for training. Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}") + raise ValueError( + f"Invalid config.model_name '{model_name}' for training. Allowed values: {sorted(_ALLOWED_TRAINING_MODEL_NAMES)}" + ) + def string_to_bool(s: str) -> bool: if s.lower() == "true": @@ -196,12 +201,14 @@ def user_init(raw_keys): raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) # Verify qkv is sharded across sequence. if raw_keys["attention"] == "ring" or raw_keys["attention_sharding_uniform"]: - max_logging.log(f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set.") + max_logging.log( + f"Adding sequence sharding to q and kv if not already present because {raw_keys['attention']}=='ring' or {raw_keys['attention_sharding_uniform']} is set." + ) logical_axis_rules = list(raw_keys["logical_axis_rules"]) max_logging.log(f"Initial logical axis rules: {logical_axis_rules}") new_rules = [] - q_seq_sharding = (LENGTH, "fsdp") - kv_seq_sharding = (KV_LENGTH, "fsdp") + q_seq_sharding = (LENGTH, "context") + kv_seq_sharding = (KV_LENGTH, "context") if q_seq_sharding not in logical_axis_rules: logical_axis_rules.append(q_seq_sharding) if kv_seq_sharding not in logical_axis_rules: @@ -211,7 +218,7 @@ def user_init(raw_keys): if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") new_rules.append(ring_attention_axis_rule) - else: # attention =flash but sequence parallel sharding requested for both self and cross attention + else: # attention =flash but sequence parallel sharding requested for both self and cross attention for seq_parallel_axis_rule in SEQUENCE_PARALLEL_AXIS_RULES: if seq_parallel_axis_rule not in logical_axis_rules: max_logging.log(f"Adding sequence parallel attention axis rule {seq_parallel_axis_rule}") @@ -244,9 +251,10 @@ def user_init(raw_keys): raw_keys["total_train_batch_size"] = max_utils.get_global_batch_size(raw_keys["per_device_batch_size"]) raw_keys["num_slices"] = get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) - raw_keys["global_batch_size_to_load"], raw_keys["global_batch_size_to_train_on"] = ( - _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) - ) + ( + raw_keys["global_batch_size_to_load"], + raw_keys["global_batch_size_to_train_on"], + ) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"]) def get_num_slices(raw_keys): diff --git a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py index 218117eb..c55a49c4 100644 --- a/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -528,13 +528,11 @@ def step_2(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: ) def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: - timestep_list = jnp.array( - [ - state.timesteps[step_index - 2], - state.timesteps[step_index - 1], - state.timesteps[step_index], - ] - ) + timestep_list = jnp.array([ + state.timesteps[step_index - 2], + state.timesteps[step_index - 1], + state.timesteps[step_index], + ]) return self.multistep_dpm_solver_third_order_update( state, state.model_outputs, diff --git a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py index 863fa26c..f6ace5fc 100644 --- a/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py @@ -256,7 +256,6 @@ def add_noise( timesteps: jnp.ndarray, flux: bool = False, ) -> jnp.ndarray: - if flux: t = state.timesteps[timesteps] t = t[:, None, None] diff --git a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py index 03a47fd4..b2c7d96a 100644 --- a/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py @@ -136,13 +136,11 @@ def __init__( if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if ( - sum( - [ - self.config.use_beta_sigmas, - self.config.use_exponential_sigmas, - self.config.use_karras_sigmas, - ] - ) + sum([ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ]) > 1 ): raise ValueError( diff --git a/src/maxdiffusion/schedulers/scheduling_utils_flax.py b/src/maxdiffusion/schedulers/scheduling_utils_flax.py index e1690ba8..d38f1446 100644 --- a/src/maxdiffusion/schedulers/scheduling_utils_flax.py +++ b/src/maxdiffusion/schedulers/scheduling_utils_flax.py @@ -262,7 +262,8 @@ def create(cls, scheduler): elif config.beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. betas = ( - jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) ** 2 + jnp.linspace(config.beta_start**0.5, config.beta_end**0.5, config.num_train_timesteps, dtype=scheduler.dtype) + ** 2 ) elif config.beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule diff --git a/src/maxdiffusion/tests/__init__.py b/src/maxdiffusion/tests/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/tests/__init__.py +++ b/src/maxdiffusion/tests/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index c2180240..f345ab11 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest @@ -39,7 +39,7 @@ def test_splash_attention(self): pyconfig.initialize( [ None, - os.path.join(THIS_DIR, "..", "configs", "base21.yml"), + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), 'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,' '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,' '"block_q_dq": 512, "block_kv_dq": 512}', diff --git a/src/maxdiffusion/tests/configuration_utils_test.py b/src/maxdiffusion/tests/configuration_utils_test.py index 29f3f8a7..df46ea75 100644 --- a/src/maxdiffusion/tests/configuration_utils_test.py +++ b/src/maxdiffusion/tests/configuration_utils_test.py @@ -1,18 +1,19 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import json import os diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index 4cb290ef..a58d5dcc 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os import unittest from unittest.mock import Mock diff --git a/src/maxdiffusion/tests/generate_flux_smoke_test.py b/src/maxdiffusion/tests/generate_flux_smoke_test.py index 68968bfd..4f174716 100644 --- a/src/maxdiffusion/tests/generate_flux_smoke_test.py +++ b/src/maxdiffusion/tests/generate_flux_smoke_test.py @@ -1,18 +1,19 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py index a5bb289f..e2b4d772 100644 --- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py +++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os import unittest import pytest diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py index d0c02044..b6722b3a 100644 --- a/src/maxdiffusion/tests/generate_smoke_test.py +++ b/src/maxdiffusion/tests/generate_smoke_test.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os import unittest import pytest diff --git a/src/maxdiffusion/tests/gradient_checkpoint_test.py b/src/maxdiffusion/tests/gradient_checkpoint_test.py index ca237d52..a4d6f6cd 100644 --- a/src/maxdiffusion/tests/gradient_checkpoint_test.py +++ b/src/maxdiffusion/tests/gradient_checkpoint_test.py @@ -1,17 +1,17 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import unittest diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 1141ec8c..0b55c8f8 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -1,17 +1,17 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. """ import os @@ -70,7 +70,6 @@ def setUp(self): InputPipelineInterface.dummy_data = {} def test_make_dreambooth_train_iterator(self): - instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class" class_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/class_class" local_dir = "/tmp/" @@ -135,7 +134,9 @@ def test_make_dreambooth_train_iterator(self): cleanup(instance_class_local_dir) cleanup(class_class_local_dir) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_pokemon_hf_iterator(self): pyconfig.initialize( [ @@ -239,7 +240,9 @@ def test_make_pokemon_hf_iterator_sdxl(self): assert data["input_ids"].shape == (device_count, 2, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_pokemon_tf_iterator_cache(self): pyconfig.initialize( [ @@ -302,7 +305,9 @@ def test_make_pokemon_tf_iterator_cache(self): config.resolution // vae_scale_factor, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_pokemon_iterator_no_cache(self): pyconfig.initialize( [ @@ -435,7 +440,9 @@ def test_make_pokemon_iterator_sdxl_cache(self): config.resolution // vae_scale_factor, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_laion_grain_iterator(self): try: subprocess.check_output( @@ -492,7 +499,9 @@ def test_make_laion_grain_iterator(self): 8, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_make_laion_tfrecord_iterator(self): pyconfig.initialize( [ @@ -553,7 +562,9 @@ def _parse_tfrecord_fn(example): 8, ) - @pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace") + @pytest.mark.skip( + "This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace" + ) def test_tfrecord(self): """Validate latents match a deterministic output image""" diff --git a/src/maxdiffusion/tests/legacy_hf_tests/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/tests/conftest.py b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py similarity index 79% rename from tests/conftest.py rename to src/maxdiffusion/tests/legacy_hf_tests/conftest.py index 42d0bac8..730f6f27 100644 --- a/tests/conftest.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py @@ -31,14 +31,14 @@ def pytest_addoption(parser): - from maxdiffusion.utils.testing_utils import pytest_addoption_shared + from maxdiffusion.utils.testing_utils import pytest_addoption_shared - pytest_addoption_shared(parser) + pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main + from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main - make_reports = terminalreporter.config.getoption("--make-reports") - if make_reports: - pytest_terminal_summary_main(terminalreporter, id=make_reports) + make_reports = terminalreporter.config.getoption("--make-reports") + if make_reports: + pytest_terminal_summary_main(terminalreporter, id=make_reports) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py new file mode 100644 index 00000000..1caabdbf --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py @@ -0,0 +1,83 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import inspect + +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + + +@require_flax +class FlaxModelTesterMixin: + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) + + output = model.apply(variables, inputs_dict["sample"]) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) + + output = model.apply(variables, inputs_dict["sample"]) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py new file mode 100644 index 00000000..f514f708 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py @@ -0,0 +1,118 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import gc +import unittest + +from maxdiffusion import FlaxUNet2DConditionModel +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import load_hf_numpy, require_flax, slow +from parameterized import parameterized + + +if is_flax_available(): + import jax + import jax.numpy as jnp + + +@slow +@require_flax +class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): + + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + revision = "bf16" if fp16 else None + + model, params = FlaxUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", dtype=dtype, revision=revision) + return model, params + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], + [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], + [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], + [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], + # fmt: on + ] + ) + def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py new file mode 100644 index 00000000..295ed508 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py @@ -0,0 +1,55 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest + +from maxdiffusion import FlaxAutoencoderKL +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import require_flax + +from .test_modeling_common_flax import FlaxModelTesterMixin + + +if is_flax_available(): + import jax + + +@require_flax +class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): + model_class = FlaxAutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + prng_key = jax.random.PRNGKey(0) + image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) + + return {"sample": image, "prng_key": prng_key} + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_01.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_01.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_02.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_02.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_03.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_03.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_03.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_03.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_04.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_04.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_04.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_04.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_05.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_05.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_05.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_05.npy diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py new file mode 100644 index 00000000..45583a2f --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py @@ -0,0 +1,939 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import tempfile +import unittest +from typing import Dict, List, Tuple + +from maxdiffusion import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from jax import random + + jax_device = jax.default_backend() + + +@require_flax +class FlaxSchedulerCommonTest(unittest.TestCase): + scheduler_classes = () + forward_default_kwargs = () + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + key1, key2 = random.split(random.PRNGKey(0)) + sample = random.uniform(key1, (batch_size, num_channels, height, width)) + + return sample, key2 + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = jnp.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + return jnp.transpose(sample, (3, 0, 1, 2)) + + def get_scheduler_config(self): + raise NotImplementedError + + def dummy_model(self): + def model(sample, t, *args): + return sample * t / (t + 1) + + return model + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + + +@require_flax +class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDPMScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "variance_type": "fixed_small", + "clip_sample": True, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [1, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_variance_type(self): + for variance in ["fixed_small", "fixed_large", "other"]: + self.check_over_configs(variance_type=variance) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_time_indices(self): + for t in [0, 500, 999]: + self.check_over_forward(time_step=t) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_trained_timesteps = len(scheduler) + + model = self.dummy_model() + sample = self.dummy_sample_deter + key1, key2 = random.split(random.PRNGKey(0)) + + for t in reversed(range(num_trained_timesteps)): + # 1. predict noise residual + residual = model(sample, t) + + # 2. predict previous mean of sample x_t-1 + output = scheduler.step(state, residual, t, sample, key1) + pred_prev_sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + # if t > 0: + # noise = self.dummy_sample_deter + # variance = scheduler.get_variance(t) ** (0.5) * noise + # + # sample = pred_prev_sample + variance + sample = pred_prev_sample + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 257.28717) < 1.5e-2 + assert abs(result_mean - 0.33500) < 2e-5 + else: + assert abs(result_sum - 257.33148) < 1e-2 + assert abs(result_mean - 0.335057) < 1e-3 + + +@require_flax +class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDIMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + key1, key2 = random.split(random.PRNGKey(0)) + + num_inference_steps = 10 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + state = scheduler.set_timesteps(state, num_inference_steps) + + for t in state.timesteps: + residual = model(sample, t) + output = scheduler.step(state, residual, t, sample) + sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + return sample + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 500, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() + + def test_steps_trailing(self): + self.check_over_configs(timestep_spacing="trailing") + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(timestep_spacing="trailing") + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([999, 799, 599, 399, 199])).all() + + def test_steps_leading(self): + self.check_over_configs(timestep_spacing="leading") + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(timestep_spacing="leading") + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([800, 600, 400, 200, 0])).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 10, 49]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 172.0067) < 1e-2 + assert abs(result_mean - 0.223967) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 149.82944) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + else: + assert abs(result_sum - 149.8295) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + pass + # FIXME: both result_sum and result_mean are nan on TPU + # assert jnp.isnan(result_sum) + # assert jnp.isnan(result_mean) + else: + assert abs(result_sum - 149.0784) < 1e-2 + assert abs(result_mean - 0.1941) < 1e-3 + + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + +@require_flax +class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + state = state.replace(ets=dummy_past_residuals[:]) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + new_state = new_state.replace(ets=dummy_past_residuals[:]) + + (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + pass + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + + # copy over dummy past residual (must be after setting timesteps) + new_state.replace(ets=dummy_past_residuals[:]) + + output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + for i, t in enumerate(state.prk_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_prk(state, residual, t, sample) + + for i, t in enumerate(state.plms_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_plms(state, residual, t, sample) + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + state = state.replace(ets=dummy_past_residuals[:]) + + output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 10, shape=()) + assert jnp.equal( + state.timesteps, + jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), + ).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(state.prk_timesteps[:2]): + sample, state = scheduler.step_prk(state, residual, t, sample) + + def test_inference_plms_no_past_residuals(self): + with self.assertRaises(ValueError): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 198.1275) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + else: + assert abs(result_sum - 198.1318) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 8e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9466) < 1e-2 + assert abs(result_mean - 0.24342) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 8e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9482) < 1e-2 + assert abs(result_mean - 0.2434) < 1e-3 diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py new file mode 100644 index 00000000..821adcfe --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py @@ -0,0 +1,104 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import jax.numpy as jnp +from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler +import os +from maxdiffusion import max_logging +import torch +import unittest +from absl.testing import absltest +import numpy as np + + +class rfTest(unittest.TestCase): + + def test_rf_steps(self): + # --- Simulation Parameters --- + latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) + inference_steps_count = 5 # Number of steps for the denoising process + + # --- Run the Simulation --- + max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") + + seed = 42 + device = "cpu" + max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") + + generator = torch.Generator(device=device).manual_seed(seed) + + # 1. Instantiate the scheduler + config = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": None, + "base_resolution": None, + "sampler": "LinearQuadratic", + } + flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) + + # 2. Create and set initial state for the scheduler + flax_state = flax_scheduler.create_state() + flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape) + max_logging.log("\nScheduler initialized.") + max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}") + + # 3. Prepare the initial noisy latent sample + # In a real scenario, this would typically be pure random noise (e.g., N(0,1)) + # For simulation, we'll generate it. + + sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) + max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}") + + # 4. Simulate the denoising loop + max_logging.log("\nStarting denoising loop:") + for i, t in enumerate(flax_state.timesteps): + max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}") + + # Simulate model_output (e.g., noise prediction from a UNet) + model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) + + # Call the scheduler's step function + scheduler_output = flax_scheduler.step( + state=flax_state, + model_output=model_output, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + + sample = scheduler_output.prev_sample # Update the sample for the next step + flax_state = scheduler_output.state # Update the state for the next step + + # Compare with pytorch implementation + base_dir = os.path.dirname(__file__) + ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref") + ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy") + if os.path.exists(ref_filename): + pt_sample = np.load(ref_filename) + torch.testing.assert_close(np.array(sample), pt_sample) + else: + max_logging.log(f"Warning: Reference file not found: {ref_filename}") + + max_logging.log("\nDenoising loop completed.") + max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}") + max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}") + + max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.") + + +if __name__ == "__main__": + absltest.main() diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py new file mode 100644 index 00000000..657a5bb8 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py @@ -0,0 +1,658 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/tests/schedulers/test_scheduler_unipc.py + +import tempfile + +import torch +import jax.numpy as jnp +from typing import Dict, List, Tuple + +from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import ( + FlaxUniPCMultistepScheduler, +) +from maxdiffusion import FlaxDPMSolverMultistepScheduler + +from .test_scheduler_flax import FlaxSchedulerCommonTest + + +class FlaxUniPCMultistepSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxUniPCMultistepScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + sample = torch.rand((batch_size, num_channels, height, width)) + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + jax_sample = jnp.asarray(sample) + return jax_sample + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + "solver_type": "bh2", + "final_sigmas_type": "sigma_min", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + # copy over dummy past residuals + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + # Copy over dummy past residuals to new_state as well + new_state = new_state.replace(model_outputs=initial_model_outputs) + + output_sample, output_state = sample, state + new_output_sample, new_output_state = sample, new_state + # Need to iterate through the steps as UniPC maintains history over steps + # The loop for solver_order + 1 steps is crucial for UniPC's history logic. + for i in range(time_step, time_step + scheduler.config.solver_order + 1): + # Ensure time_step + i is within the bounds of timesteps + if i >= len(output_state.timesteps): + break + t = output_state.timesteps[i] + step_output = scheduler.step( + state=output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + output_sample = step_output.prev_sample + output_state = step_output.state + + new_step_output = new_scheduler.step( + state=new_output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=new_output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + new_output_sample = new_step_output.prev_sample + new_output_state = new_step_output.state + + self.assertTrue( + jnp.allclose(output_sample, new_output_sample, atol=1e-5), + "Scheduler outputs are not identical", + ) + # Also assert that states are identical + self.assertEqual(output_state.step_index, new_output_state.step_index) + self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) + self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) + # Comparing model_outputs (history) directly: + if output_state.model_outputs is not None and new_output_state.model_outputs is not None: + for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): + self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # copy over dummy past residuals + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + + # What is this doing? + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"): + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + # Copy over dummy past residuals to new_state as well + new_state = new_state.replace(model_outputs=initial_model_outputs) + + output_sample, output_state = sample, state + new_output_sample, new_output_state = sample, new_state + + # Need to iterate through the steps as UniPC maintains history over steps + # The loop for solver_order + 1 steps is crucial for UniPC's history logic. + for i in range(time_step, time_step + scheduler.config.solver_order + 1): + # Ensure time_step + i is within the bounds of timesteps + if i >= len(output_state.timesteps): + break + + t = output_state.timesteps[i] + + step_output = scheduler.step( + state=output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + output_sample = step_output.prev_sample + output_state = step_output.state + + new_step_output = new_scheduler.step( + state=new_output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=new_output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + new_output_sample = new_step_output.prev_sample + new_output_state = new_step_output.state + + self.assertTrue( + jnp.allclose(output_sample, new_output_sample, atol=1e-5), + "Scheduler outputs are not identical", + ) + # Also assert that states are identical + self.assertEqual(output_state.step_index, new_output_state.step_index) + self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) + self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) + # Comparing model_outputs (history) directly: + if output_state.model_outputs is not None and new_output_state.model_outputs is not None: + for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): + self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") + + def full_loop(self, scheduler=None, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + if scheduler is None: + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + else: + state = scheduler.create_state() # Ensure state is fresh for the loop + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + + # scheduler.step in common test receives state, residual, t, sample + step_output = scheduler.step( + state=state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + sample = step_output.prev_sample + state = step_output.state # Update state for next iteration + + return sample + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() # Create initial state + + sample = self.dummy_sample # Get sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # Copy over dummy past residuals (must be done after set_timesteps) + dummy_past_model_outputs = [ + 0.2 * sample, + 0.15 * sample, + 0.10 * sample, + ] + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + + time_step_0 = state.timesteps[5] + time_step_1 = state.timesteps[6] + + output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample + output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler_config = self.get_scheduler_config() + scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) + sample_1 = self.full_loop(scheduler=scheduler_1) + result_mean_1 = jnp.mean(jnp.abs(sample_1)) + + assert abs(result_mean_1.item() - 0.2464) < 1e-3 + + scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance + sample_2 = self.full_loop(scheduler=scheduler_2) + result_mean_2 = jnp.mean(jnp.abs(sample_2)) + + self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency + + assert abs(result_mean_2.item() - 0.2464) < 1e-3 + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + with self.assertRaises(NotImplementedError): + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + solver_order=order, + solver_type=solver_type, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_rescale_betas_zero_snr(self): + for rescale_zero_terminal_snr in [True, False]: + self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) + + def test_solver_order_and_type(self): + for solver_type in ["bh1", "bh2"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" + + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(time_step=0, num_inference_steps=num_inference_steps) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2464) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2925) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1966) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: + scheduler_config = self.get_scheduler_config( + thresholding=False, + dynamic_thresholding_ratio=0, + prediction_type=prediction_type, + solver_order=order, + solver_type=solver_type, + ) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.astype(jnp.bfloat16) + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + # sample is casted to fp32 inside step and output should be fp32. + self.assertEqual(sample.dtype, jnp.float32) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index:] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" + assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" + + +class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + width = 8 + + torch_sample = torch.rand((batch_size, num_channels, width)) + jax_sample = jnp.asarray(torch_sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + jax_sample = jnp.asarray(sample) + return jax_sample + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2898) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1944) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index:] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" + assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" + + def test_beta_sigmas(self): + # self.check_over_configs(use_beta_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_beta_sigmas=True) + + def test_exponential_sigmas(self): + # self.check_over_configs(use_exponential_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_exponential_sigmas=True) diff --git a/src/maxdiffusion/tests/ltx_transformer_step_test.py b/src/maxdiffusion/tests/ltx_transformer_step_test.py index 9398c915..c868bd95 100644 --- a/src/maxdiffusion/tests/ltx_transformer_step_test.py +++ b/src/maxdiffusion/tests/ltx_transformer_step_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import torch @@ -104,7 +104,7 @@ def test_one_step_transformer(self): devices_array = create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) base_dir = os.path.dirname(__file__) - config_path = os.path.join(base_dir, "../models/ltx_video/xora_v1.2-13B-balanced-128.json") + config_path = os.path.join(base_dir, "../models/ltx_video/ltxv-13B.json") with open(config_path, "r") as f: model_config = json.load(f) diff --git a/src/maxdiffusion/tests/maxdiffusion_utils_test.py b/src/maxdiffusion/tests/maxdiffusion_utils_test.py index 23dd3ee3..4b29d365 100644 --- a/src/maxdiffusion/tests/maxdiffusion_utils_test.py +++ b/src/maxdiffusion/tests/maxdiffusion_utils_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest diff --git a/src/maxdiffusion/tests/text_encoders_test.py b/src/maxdiffusion/tests/text_encoders_test.py index e7d3d6dd..c91bca9a 100644 --- a/src/maxdiffusion/tests/text_encoders_test.py +++ b/src/maxdiffusion/tests/text_encoders_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest @@ -36,7 +36,6 @@ def setUp(self): @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_t5_text_encoder(self): - text_encoder = FlaxT5EncoderModel.from_pretrained("ariG23498/t5-v1-1-xxl-flax") tokenizer_2 = T5TokenizerFast.from_pretrained("ariG23498/t5-v1-1-xxl-flax") @@ -47,7 +46,6 @@ def test_flux_t5_text_encoder(self): @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_clip_text_encoder(self): - text_encoder = FlaxCLIPTextModel.from_pretrained( "black-forest-labs/FLUX.1-dev", subfolder="text_encoder", from_pt=True, dtype="bfloat16" ) diff --git a/src/maxdiffusion/tests/train_smoke_test.py b/src/maxdiffusion/tests/train_smoke_test.py index a7d0f4b8..f5f6df00 100644 --- a/src/maxdiffusion/tests/train_smoke_test.py +++ b/src/maxdiffusion/tests/train_smoke_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Smoke test """ import os diff --git a/src/maxdiffusion/tests/unet_test.py b/src/maxdiffusion/tests/unet_test.py index 562fb5a3..0bbf706f 100644 --- a/src/maxdiffusion/tests/unet_test.py +++ b/src/maxdiffusion/tests/unet_test.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" """ Smoke test """ import os diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index e3a46b10..17e9b211 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import unittest @@ -38,7 +38,6 @@ def setUp(self): @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions") def test_flux_vae(self): - img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png") base_image = np.array(Image.open(img_url)).astype(np.uint8) img_min = np.min(base_image) diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index 81a38670..a0a529f1 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -1,15 +1,15 @@ """ - Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Copyright 2025 Google LLC +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import unittest from unittest.mock import patch, MagicMock @@ -19,6 +19,7 @@ from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 from maxdiffusion.pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1 + class WanCheckpointer2_1Test(unittest.TestCase): """Tests for WAN 2.1 checkpointer.""" @@ -237,6 +238,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, m self.assertEqual(opt_state["learning_rate"], 0.002) self.assertEqual(step, 1) + class WanCheckpointerI2V_2_1Test(unittest.TestCase): """Tests for WAN 2.1 I2V checkpointer.""" @@ -324,6 +326,7 @@ def test_load_checkpoint_with_optimizer(self, mock_from_checkpoint, mock_create_ self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) + class WanCheckpointerI2V_2_2Test(unittest.TestCase): """Tests for WAN 2.2 I2V checkpointer.""" @@ -447,6 +450,7 @@ def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline_i2 self.assertEqual(opt_state["learning_rate"], 0.002) self.assertEqual(step, 1) + class WanCheckpointerEdgeCasesTest(unittest.TestCase): """Tests for edge cases and error handling.""" diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index c1044cc3..4d54525d 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import jax @@ -65,7 +65,6 @@ def setUp(self): devices_array = create_device_mesh(config) self.mesh = Mesh(devices_array, config.mesh_axes) - def test_rotary_pos_embed(self): batch_size = 1 channels = 16 @@ -126,9 +125,7 @@ def test_wan_time_text_embedding(self): encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim) dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer( - dummy_timestep, dummy_encoder_hidden_states - ) + temb, timestep_proj, encoder_hidden_states, _, _ = layer(dummy_timestep, dummy_encoder_hidden_states) assert temb.shape == (batch_size, dim) assert timestep_proj.shape == (batch_size, time_proj_dim) assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim) @@ -198,12 +195,7 @@ def test_wan_block(self): def test_wan_attention(self): for attention_kernel in ["flash", "tokamax_flash"]: pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - f"attention={attention_kernel}" - ], - unittest=True + [None, os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), f"attention={attention_kernel}"], unittest=True ) config = pyconfig.config batch_size = 1 @@ -286,7 +278,9 @@ def test_wan_model(self): batch_size = 1 num_layers = 1 with nn_partitioning.axis_rules(config.logical_axis_rules): - wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) + wan_model = WanModel( + rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers + ) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) @@ -400,6 +394,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize mock_config.weight_quantization_calibration_method = "fixed,-224,224" mock_config.act_quantization_calibration_method = "fixed,-224,224" mock_config.bwd_quantization_calibration_method = "absmax" + mock_config.global_batch_size_to_train_on = 32 mock_model = Mock(spec=WanModel) mock_pipeline = Mock() diff --git a/src/maxdiffusion/tests/wan_vace_transformer_test.py b/src/maxdiffusion/tests/wan_vace_transformer_test.py index 9864e64c..05b04f76 100644 --- a/src/maxdiffusion/tests/wan_vace_transformer_test.py +++ b/src/maxdiffusion/tests/wan_vace_transformer_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import jax @@ -43,6 +43,7 @@ class WanVaceTransformerTest(unittest.TestCase): + def test_wan_vace_block_returns_the_correct_shape(self): key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -117,5 +118,6 @@ def test_wan_vace_block_returns_the_correct_shape(self): assert conditioning_states.shape == dummy_hidden_states.shape assert control_hidden_states.shape == dummy_hidden_states.shape + if __name__ == "__main__": absltest.main() diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py index b2ffbc3b..73db7173 100644 --- a/src/maxdiffusion/tests/wan_vae_test.py +++ b/src/maxdiffusion/tests/wan_vae_test.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import functools @@ -211,7 +211,6 @@ def test_wanrms_norm(self): assert np.allclose(output_np, torch_output_np) is True def test_zero_padded_conv(self): - key = jax.random.key(0) rngs = nnx.Rngs(key) @@ -267,7 +266,7 @@ def test_wan_resample(self): # channels is always last here input_shape = (batch, t, h, w, dim) dummy_input = jnp.ones(input_shape) - output = wan_resample(dummy_input) + output, _, _ = wan_resample(dummy_input) assert output.shape == (batch, t, h // 2, w // 2, dim) def test_3d_conv(self): @@ -348,7 +347,7 @@ def test_wan_residual(self): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) - dummy_output = wan_residual_block(dummy_input) + dummy_output, _, _ = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape # --- Test Case 1: different in/out dim --- in_dim = 96 @@ -357,7 +356,7 @@ def test_wan_residual(self): wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) - dummy_output = wan_residual_block(dummy_input) + dummy_output, _, _ = wan_residual_block(dummy_input) assert dummy_output.shape == expected_output_shape def test_wan_attention(self): @@ -372,7 +371,7 @@ def test_wan_attention(self): with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): wan_attention = WanAttentionBlock(dim=dim, rngs=rngs) dummy_input = jnp.ones(input_shape) - output = wan_attention(dummy_input) + output, _, _ = wan_attention(dummy_input) assert output.shape == input_shape def test_wan_midblock(self): @@ -397,7 +396,7 @@ def test_wan_midblock(self): with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh) dummy_input = jnp.ones(input_shape) - output = wan_midblock(dummy_input) + output, _, _ = wan_midblock(dummy_input) assert output.shape == input_shape def test_wan_decode(self): @@ -523,11 +522,11 @@ def vae_encode(video, wan_vae, vae_cache, key): params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params) wan_vae = nnx.merge(graphdef, params) - p_vae_encode = jax.jit(functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key)) + p_vae_encode = functools.partial(vae_encode, wan_vae=wan_vae, vae_cache=vae_cache, key=key) original_video_shape = original_video.shape latent = p_vae_encode(original_video) - jitted_decode = jax.jit(functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False)) + jitted_decode = functools.partial(wan_vae.decode, feat_cache=vae_cache, return_dict=False) video = jitted_decode(latent)[0] video = jnp.transpose(video, (0, 4, 1, 2, 3)) assert video.shape == original_video_shape diff --git a/src/maxdiffusion/tpu_utils.py b/src/maxdiffusion/tpu_utils.py index 9ea03e7c..5697f60c 100644 --- a/src/maxdiffusion/tpu_utils.py +++ b/src/maxdiffusion/tpu_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax diff --git a/src/maxdiffusion/train.py b/src/maxdiffusion/train.py index 60657e0b..1bfcc942 100644 --- a/src/maxdiffusion/train.py +++ b/src/maxdiffusion/train.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py index 05cdae44..e341ae1f 100644 --- a/src/maxdiffusion/train_flux.py +++ b/src/maxdiffusion/train_flux.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py index 60170a85..ad11c1e4 100644 --- a/src/maxdiffusion/train_sdxl.py +++ b/src/maxdiffusion/train_sdxl.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 79e65e99..8db92a40 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import numpy as np import jax diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py index fea15720..d272ca23 100644 --- a/src/maxdiffusion/train_wan.py +++ b/src/maxdiffusion/train_wan.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from typing import Sequence @@ -35,7 +35,10 @@ def main(argv: Sequence[str]) -> None: config = pyconfig.config validate_train_config(config) max_logging.log(f"Found {jax.device_count()} devices.") - flax.config.update("flax_always_shard_variable", False) + try: + flax.config.update("flax_always_shard_variable", False) + except LookupError: + pass train(config) diff --git a/src/maxdiffusion/trainers/__init__.py b/src/maxdiffusion/trainers/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/trainers/__init__.py +++ b/src/maxdiffusion/trainers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index a9f17adc..7bb2e26b 100644 --- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from abc import abstractmethod import time diff --git a/src/maxdiffusion/trainers/dreambooth_trainer.py b/src/maxdiffusion/trainers/dreambooth_trainer.py index 40a40190..a2bd8991 100644 --- a/src/maxdiffusion/trainers/dreambooth_trainer.py +++ b/src/maxdiffusion/trainers/dreambooth_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" from pathlib import Path import time @@ -116,7 +116,6 @@ def get_data_shardings(self): return data_sharding def load_dataset(self, pipeline, params, train_states): - return make_dreambooth_train_iterator( self.config, self.mesh, @@ -183,7 +182,6 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da return p_train_step def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, learning_rate_scheduler): - writer = max_utils.initialize_summary_writer(self.config) unet_state = train_states["unet_state"] text_encoder_state = train_states["text_encoder_state"] @@ -265,7 +263,6 @@ def _train_step(unet_state, text_encoder_state, batch, train_rng, config, pipeli state_params = {"text_encoder": text_encoder_state.params, "unet": unet_state.params} def compute_loss(state_params): - encoder_hidden_states = encode(input_ids, pipeline.text_encoder, state_params["text_encoder"]) # Sample noise that we'll add to the latents diff --git a/src/maxdiffusion/trainers/flux_trainer.py b/src/maxdiffusion/trainers/flux_trainer.py index 74b4f259..6cda5682 100644 --- a/src/maxdiffusion/trainers/flux_trainer.py +++ b/src/maxdiffusion/trainers/flux_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os from functools import partial @@ -76,7 +76,6 @@ def calculate_tflops(self, pipeline): return per_device_tflops def start_training(self): - # Hook # self.pre_training_steps() # Load checkpoint - will load or create states @@ -247,6 +246,18 @@ def load_dataset(self, pipeline, params, train_states): total_train_batch_size = self.total_train_batch_size mesh = self.mesh + # If using synthetic data + if config.dataset_type == "synthetic": + return make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + total_train_batch_size, + pipeline=pipeline, # Pass pipeline to extract dimensions + is_training=True, + ) + encode_fn = partial( pipeline.encode_prompt, clip_tokenizer=pipeline.clip_tokenizer, @@ -314,7 +325,6 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da return p_train_step def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): - writer = max_utils.initialize_summary_writer(self.config) flux_state = train_states[FLUX_STATE_KEY] num_model_parameters = max_utils.calculate_num_params_from_pytree(flux_state.params) diff --git a/src/maxdiffusion/trainers/sdxl_trainer.py b/src/maxdiffusion/trainers/sdxl_trainer.py index a68cc617..fc442117 100644 --- a/src/maxdiffusion/trainers/sdxl_trainer.py +++ b/src/maxdiffusion/trainers/sdxl_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os from functools import partial @@ -176,7 +176,6 @@ def prepare_sample(features): return data_iterator def compile_train_step(self, pipeline, params, train_states, state_shardings, data_shardings): - self.rng, train_rngs = jax.random.split(self.rng) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): p_train_step = jax.jit( @@ -208,7 +207,6 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da return p_train_step def training_loop(self, p_train_step, pipeline, params, train_states, data_iterator, unet_learning_rate_scheduler): - writer = max_utils.initialize_summary_writer(self.config) writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True) writer_thread.start() diff --git a/src/maxdiffusion/trainers/stable_diffusion_trainer.py b/src/maxdiffusion/trainers/stable_diffusion_trainer.py index 5844df3d..a89c22ac 100644 --- a/src/maxdiffusion/trainers/stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/stable_diffusion_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import sys diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index f23836a5..a3a1fee6 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import os import datetime @@ -20,6 +20,7 @@ import pprint import numpy as np import threading +from contextlib import nullcontext from concurrent.futures import ThreadPoolExecutor import tensorflow as tf import jax.numpy as jnp @@ -105,7 +106,6 @@ def create_scheduler(self): @staticmethod def calculate_tflops(pipeline): - maxdiffusion_config = pipeline.config # Model configuration height = pipeline.config.height @@ -164,7 +164,18 @@ def get_eval_data_shardings(self, mesh): data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding} return data_sharding - def load_dataset(self, mesh, is_training=True): + def load_dataset(self, mesh, pipeline=None, is_training=True): + """ + Load dataset - supports both real tfrecord and synthetic data. + + Args: + mesh: JAX mesh for sharding + pipeline: Optional WAN pipeline to extract dimensions from (for synthetic data) + is_training: Whether this is for training or evaluation + + Returns: + Data iterator + """ # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314 # Image pre-training - txt2img 256px # Image-video joint training - stage 1. 256 px images and 192px 5 sec videos at fps=16 @@ -173,6 +184,21 @@ def load_dataset(self, mesh, is_training=True): # prompt embeds shape: (1, 512, 4096) # For now, we will pass the same latents over and over # TODO - create a dataset + + config = self.config + + # If using synthetic data + if config.dataset_type == "synthetic": + return make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + config.global_batch_size_to_load, + pipeline=pipeline, # Pass pipeline to extract dimensions + is_training=is_training, + ) + config = self.config if config.dataset_type != "tfrecord" and not config.cache_latents_text_encoder_outputs: raise ValueError( @@ -210,8 +236,8 @@ def prepare_sample_eval(features): return data_iterator def start_training(self): - - pipeline, opt_state, step = self.checkpointer.load_checkpoint() + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + pipeline, opt_state, step = self.checkpointer.load_checkpoint() restore_args = {} if opt_state and step: restore_args = {"opt_state": opt_state, "step": step} @@ -226,7 +252,7 @@ def start_training(self): del pipeline.vae_cache mesh = pipeline.mesh - train_data_iterator = self.load_dataset(mesh, is_training=True) + train_data_iterator = self.load_dataset(mesh, pipeline=pipeline, is_training=True) # Load FlowMatch scheduler scheduler, scheduler_state = self.create_scheduler() @@ -309,7 +335,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60) max_logging.log(pretty_string) max_logging.log("------------------------------------------------") - max_utils.delete_pytree(params) + if self.config.hardware != "gpu": + max_utils.delete_pytree(params) data_shardings = self.get_data_shardings(mesh) eval_data_shardings = self.get_eval_data_shardings(mesh) @@ -364,10 +391,19 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data if self.config.enable_profiler and step == first_profiling_step: max_utils.activate_profiler(self.config) start_step_time = datetime.datetime.now() + + # Designate the context parallel axis for sharding + if self.config.attention == "cudnn_flash_te": + from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error + + shard_guard = global_shard_guard(MeshResource(cp_resource="context")) + else: + shard_guard = nullcontext() + next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config) - with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules( - self.config.logical_axis_rules - ): + with jax.profiler.StepTraceAnnotation( + "train", step_num=step + ), pipeline.mesh, shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules): state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state) train_metric["scalar"]["learning/loss"].block_until_ready() last_step_completion = datetime.datetime.now() diff --git a/src/maxdiffusion/utils/deprecation_utils.py b/src/maxdiffusion/utils/deprecation_utils.py index bd2f6e35..a7077ed7 100644 --- a/src/maxdiffusion/utils/deprecation_utils.py +++ b/src/maxdiffusion/utils/deprecation_utils.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import inspect import warnings from typing import Any, Dict, Optional, Union diff --git a/src/maxdiffusion/utils/export_utils.py b/src/maxdiffusion/utils/export_utils.py index c540f5a9..fa394129 100644 --- a/src/maxdiffusion/utils/export_utils.py +++ b/src/maxdiffusion/utils/export_utils.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import io import random import struct diff --git a/src/maxdiffusion/utils/import_utils.py b/src/maxdiffusion/utils/import_utils.py index d83596e8..05ef72ec 100644 --- a/src/maxdiffusion/utils/import_utils.py +++ b/src/maxdiffusion/utils/import_utils.py @@ -512,29 +512,27 @@ def is_peft_available(): """ -BACKENDS_MAPPING = OrderedDict( - [ - ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), - ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), - ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), - ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), - ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), - ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), - ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), - ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), - ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), - ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), - ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), - ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), - ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), - ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), - ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), - ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), - ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), - ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), - ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), - ] -) +BACKENDS_MAPPING = OrderedDict([ + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), + ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)), + ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)), + ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)), + ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)), + ("tensorboard", (is_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), + ("compel", (is_compel_available, COMPEL_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), + ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), + ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), +]) def requires_backends(obj, backends): diff --git a/src/maxdiffusion/utils/loading_utils.py b/src/maxdiffusion/utils/loading_utils.py index 85bddb87..735d2261 100644 --- a/src/maxdiffusion/utils/loading_utils.py +++ b/src/maxdiffusion/utils/loading_utils.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import os from typing import Callable, List, Optional, Union diff --git a/src/maxdiffusion/utils/logging.py b/src/maxdiffusion/utils/logging.py index b9013a95..2fe7d87d 100644 --- a/src/maxdiffusion/utils/logging.py +++ b/src/maxdiffusion/utils/logging.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Logging utilities.""" +"""Logging utilities.""" import logging import os diff --git a/src/maxdiffusion/utils/pil_utils.py b/src/maxdiffusion/utils/pil_utils.py index cb44e025..86d07c66 100644 --- a/src/maxdiffusion/utils/pil_utils.py +++ b/src/maxdiffusion/utils/pil_utils.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ from typing import List import PIL.Image diff --git a/src/maxdiffusion/utils/testing_utils.py b/src/maxdiffusion/utils/testing_utils.py index a5e8aeae..55be62ac 100644 --- a/src/maxdiffusion/utils/testing_utils.py +++ b/src/maxdiffusion/utils/testing_utils.py @@ -1,18 +1,19 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ import functools import importlib import inspect diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index b392d39a..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ diff --git a/tests/models/__init__.py b/tests/models/__init__.py deleted file mode 100644 index b392d39a..00000000 --- a/tests/models/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ diff --git a/tests/models/test_modeling_common_flax.py b/tests/models/test_modeling_common_flax.py deleted file mode 100644 index 0fa55dcf..00000000 --- a/tests/models/test_modeling_common_flax.py +++ /dev/null @@ -1,82 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import inspect - -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import require_flax - - -if is_flax_available(): - import jax - - -@require_flax -class FlaxModelTesterMixin: - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) - - output = model.apply(variables, inputs_dict["sample"]) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["norm_num_groups"] = 16 - init_dict["block_out_channels"] = (16, 32) - - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) - - output = model.apply(variables, inputs_dict["sample"]) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_deprecated_kwargs(self): - has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters - has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" - " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" - f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" - " from `_deprecated_kwargs = []`" - ) diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py deleted file mode 100644 index ed1c8d39..00000000 --- a/tests/models/test_models_unet_2d_flax.py +++ /dev/null @@ -1,119 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import gc -import unittest - -from maxdiffusion import FlaxUNet2DConditionModel -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import load_hf_numpy, require_flax, slow -from parameterized import parameterized - - -if is_flax_available(): - import jax - import jax.numpy as jnp - - -@slow -@require_flax -class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): - def get_file_format(self, seed, shape): - return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" - - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return image - - def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - revision = "bf16" if fp16 else None - - model, params = FlaxUNet2DConditionModel.from_pretrained( - model_id, subfolder="unet", dtype=dtype, revision=revision - ) - return model, params - - def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return hidden_states - - @parameterized.expand( - [ - # fmt: off - [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], - [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], - [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], - [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], - # fmt: on - ] - ) - def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) - latents = self.get_latents(seed, fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], - [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], - [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], - [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], - # fmt: on - ] - ) - def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) - latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/tests/models/test_models_vae_flax.py b/tests/models/test_models_vae_flax.py deleted file mode 100644 index b00bd6e9..00000000 --- a/tests/models/test_models_vae_flax.py +++ /dev/null @@ -1,55 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import unittest - -from maxdiffusion import FlaxAutoencoderKL -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import require_flax - -from .test_modeling_common_flax import FlaxModelTesterMixin - - -if is_flax_available(): - import jax - - -@require_flax -class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): - model_class = FlaxAutoencoderKL - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - prng_key = jax.random.PRNGKey(0) - image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) - - return {"sample": image, "prng_key": prng_key} - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict diff --git a/tests/schedulers/__init__.py b/tests/schedulers/__init__.py deleted file mode 100644 index b392d39a..00000000 --- a/tests/schedulers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py deleted file mode 100644 index eab5cb91..00000000 --- a/tests/schedulers/test_scheduler_flax.py +++ /dev/null @@ -1,939 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import inspect -import tempfile -import unittest -from typing import Dict, List, Tuple - -from maxdiffusion import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import require_flax - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from jax import random - - jax_device = jax.default_backend() - - -@require_flax -class FlaxSchedulerCommonTest(unittest.TestCase): - scheduler_classes = () - forward_default_kwargs = () - - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - key1, key2 = random.split(random.PRNGKey(0)) - sample = random.uniform(key1, (batch_size, num_channels, height, width)) - - return sample, key2 - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = jnp.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - return jnp.transpose(sample, (3, 0, 1, 2)) - - def get_scheduler_config(self): - raise NotImplementedError - - def dummy_model(self): - def model(sample, t, *args): - return sample * t / (t + 1) - - return model - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_deprecated_kwargs(self): - for scheduler_class in self.scheduler_classes: - has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters - has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" - " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" - f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" - " deprecated argument from `_deprecated_kwargs = []`" - ) - - -@require_flax -class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDPMScheduler,) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "variance_type": "fixed_small", - "clip_sample": True, - } - - config.update(**kwargs) - return config - - def test_timesteps(self): - for timesteps in [1, 5, 100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_variance_type(self): - for variance in ["fixed_small", "fixed_large", "other"]: - self.check_over_configs(variance_type=variance) - - def test_clip_sample(self): - for clip_sample in [True, False]: - self.check_over_configs(clip_sample=clip_sample) - - def test_time_indices(self): - for t in [0, 500, 999]: - self.check_over_forward(time_step=t) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_trained_timesteps = len(scheduler) - - model = self.dummy_model() - sample = self.dummy_sample_deter - key1, key2 = random.split(random.PRNGKey(0)) - - for t in reversed(range(num_trained_timesteps)): - # 1. predict noise residual - residual = model(sample, t) - - # 2. predict previous mean of sample x_t-1 - output = scheduler.step(state, residual, t, sample, key1) - pred_prev_sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - # if t > 0: - # noise = self.dummy_sample_deter - # variance = scheduler.get_variance(t) ** (0.5) * noise - # - # sample = pred_prev_sample + variance - sample = pred_prev_sample - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 257.28717) < 1.5e-2 - assert abs(result_mean - 0.33500) < 2e-5 - else: - assert abs(result_sum - 257.33148) < 1e-2 - assert abs(result_mean - 0.335057) < 1e-3 - - -@require_flax -class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDIMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - key1, key2 = random.split(random.PRNGKey(0)) - - num_inference_steps = 10 - - model = self.dummy_model() - sample = self.dummy_sample_deter - - state = scheduler.set_timesteps(state, num_inference_steps) - - for t in state.timesteps: - residual = model(sample, t) - output = scheduler.step(state, residual, t, sample) - sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - return sample - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 500, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() - - def test_steps_trailing(self): - self.check_over_configs(timestep_spacing="trailing") - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(timestep_spacing="trailing") - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([999, 799, 599, 399, 199])).all() - - def test_steps_leading(self): - self.check_over_configs(timestep_spacing="leading") - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(timestep_spacing="leading") - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([800, 600, 400, 200, 0])).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 10, 49]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum - 172.0067) < 1e-2 - assert abs(result_mean - 0.223967) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 149.82944) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - else: - assert abs(result_sum - 149.8295) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - pass - # FIXME: both result_sum and result_mean are nan on TPU - # assert jnp.isnan(result_sum) - # assert jnp.isnan(result_mean) - else: - assert abs(result_sum - 149.0784) < 1e-2 - assert abs(result_mean - 0.1941) < 1e-3 - - def test_prediction_type(self): - for prediction_type in ["epsilon", "sample", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) - - -@require_flax -class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxPNDMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - state = state.replace(ets=dummy_past_residuals[:]) - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - new_state = new_state.replace(ets=dummy_past_residuals[:]) - - (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - pass - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residuals (must be after setting timesteps) - scheduler.ets = dummy_past_residuals[:] - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - # copy over dummy past residuals - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residual (must be after setting timesteps) - new_state.replace(ets=dummy_past_residuals[:]) - - output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - for i, t in enumerate(state.prk_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_prk(state, residual, t, sample) - - for i, t in enumerate(state.plms_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_plms(state, residual, t, sample) - - return sample - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - # copy over dummy past residuals (must be done after set_timesteps) - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - state = state.replace(ets=dummy_past_residuals[:]) - - output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 10, shape=()) - assert jnp.equal( - state.timesteps, - jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), - ).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 5, 10]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(num_inference_steps=num_inference_steps) - - def test_pow_of_3_inference_steps(self): - # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 - num_inference_steps = 27 - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # before power of 3 fix, would error on first step, so we only need to do two - for i, t in enumerate(state.prk_timesteps[:2]): - sample, state = scheduler.step_prk(state, residual, t, sample) - - def test_inference_plms_no_past_residuals(self): - with self.assertRaises(ValueError): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 198.1275) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - else: - assert abs(result_sum - 198.1318) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 8e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9466) < 1e-2 - assert abs(result_mean - 0.24342) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 8e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9482) < 1e-2 - assert abs(result_mean - 0.2434) < 1e-3 diff --git a/tests/schedulers/test_scheduler_rf.py b/tests/schedulers/test_scheduler_rf.py deleted file mode 100644 index 1d23880f..00000000 --- a/tests/schedulers/test_scheduler_rf.py +++ /dev/null @@ -1,99 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ -import jax.numpy as jnp -from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler -import os -from maxdiffusion import max_logging -import torch -import unittest -from absl.testing import absltest -import numpy as np - - - -class rfTest(unittest.TestCase): - - def test_rf_steps(self): - # --- Simulation Parameters --- - latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) - inference_steps_count = 5 # Number of steps for the denoising process - - # --- Run the Simulation --- - max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") - - seed = 42 - device = 'cpu' - max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") - - generator = torch.Generator(device=device).manual_seed(seed) - - # 1. Instantiate the scheduler - config = {'_class_name': 'RectifiedFlowScheduler', '_diffusers_version': '0.25.1', 'num_train_timesteps': 1000, 'shifting': None, 'base_resolution': None, 'sampler': 'LinearQuadratic'} - flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) - - # 2. Create and set initial state for the scheduler - flax_state = flax_scheduler.create_state() - flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape) - max_logging.log("\nScheduler initialized.") - max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}") - - # 3. Prepare the initial noisy latent sample - # In a real scenario, this would typically be pure random noise (e.g., N(0,1)) - # For simulation, we'll generate it. - - sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) - max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}") - - # 4. Simulate the denoising loop - max_logging.log("\nStarting denoising loop:") - for i, t in enumerate(flax_state.timesteps): - max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}") - - # Simulate model_output (e.g., noise prediction from a UNet) - model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) - - # Call the scheduler's step function - scheduler_output = flax_scheduler.step( - state=flax_state, - model_output=model_output, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=sample, - return_dict=True # Return a SchedulerOutput dataclass - ) - - sample = scheduler_output.prev_sample # Update the sample for the next step - flax_state = scheduler_output.state # Update the state for the next step - - # Compare with pytorch implementation - base_dir = os.path.dirname(__file__) - ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref") - ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy") - if os.path.exists(ref_filename): - pt_sample = np.load(ref_filename) - torch.testing.assert_close(np.array(sample), pt_sample) - else: - max_logging.log(f"Warning: Reference file not found: {ref_filename}") - - - max_logging.log("\nDenoising loop completed.") - max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}") - max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}") - - max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.") - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py deleted file mode 100644 index d401f54f..00000000 --- a/tests/schedulers/test_scheduler_unipc.py +++ /dev/null @@ -1,680 +0,0 @@ -# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info -# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/tests/schedulers/test_scheduler_unipc.py - -import tempfile - -import torch -import jax.numpy as jnp -from typing import Dict, List, Tuple - -from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import ( - FlaxUniPCMultistepScheduler, -) -from maxdiffusion import FlaxDPMSolverMultistepScheduler - -from .test_scheduler_flax import FlaxSchedulerCommonTest - - -class FlaxUniPCMultistepSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxUniPCMultistepScheduler,) - forward_default_kwargs = (("num_inference_steps", 25),) - - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - sample = torch.rand((batch_size, num_channels, height, width)) - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_noise_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = torch.arange(num_elems).flip(-1) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) - - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = torch.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) - - jax_sample= jnp.asarray(sample) - return jax_sample - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "solver_order": 2, - "solver_type": "bh2", - "final_sigmas_type": "sigma_min", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample = self.dummy_sample - residual = 0.1 * sample - dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - state = scheduler.set_timesteps( - state, num_inference_steps, sample.shape - ) - new_state = new_scheduler.set_timesteps( - new_state, num_inference_steps, sample.shape - ) - # copy over dummy past residuals - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - # Copy over dummy past residuals to new_state as well - new_state = new_state.replace(model_outputs=initial_model_outputs) - - - output_sample, output_state = sample, state - new_output_sample, new_output_state = sample, new_state - # Need to iterate through the steps as UniPC maintains history over steps - # The loop for solver_order + 1 steps is crucial for UniPC's history logic. - for i in range(time_step, time_step + scheduler.config.solver_order + 1): - # Ensure time_step + i is within the bounds of timesteps - if i >= len(output_state.timesteps): - break - t = output_state.timesteps[i] - step_output = scheduler.step( - state=output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - output_sample = step_output.prev_sample - output_state = step_output.state - - new_step_output = new_scheduler.step( - state=new_output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=new_output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - new_output_sample = new_step_output.prev_sample - new_output_state = new_step_output.state - - self.assertTrue( - jnp.allclose(output_sample, new_output_sample, atol=1e-5), - "Scheduler outputs are not identical", - ) - # Also assert that states are identical - self.assertEqual(output_state.step_index, new_output_state.step_index) - self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) - self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) - # Comparing model_outputs (history) directly: - if output_state.model_outputs is not None and new_output_state.model_outputs is not None: - for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): - self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample = self.dummy_sample - residual = 0.1 * sample - dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps( - state, num_inference_steps, sample.shape - ) - - # copy over dummy past residuals - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - - # What is this doing? - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"): - new_state = new_scheduler.set_timesteps( - new_state, num_inference_steps, sample.shape - ) - # Copy over dummy past residuals to new_state as well - new_state = new_state.replace(model_outputs=initial_model_outputs) - - - output_sample, output_state = sample, state - new_output_sample, new_output_state = sample, new_state - - # Need to iterate through the steps as UniPC maintains history over steps - # The loop for solver_order + 1 steps is crucial for UniPC's history logic. - for i in range(time_step, time_step + scheduler.config.solver_order + 1): - # Ensure time_step + i is within the bounds of timesteps - if i >= len(output_state.timesteps): - break - - t = output_state.timesteps[i] - - step_output = scheduler.step( - state=output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - **kwargs, - ) - output_sample = step_output.prev_sample - output_state = step_output.state - - new_step_output = new_scheduler.step( - state=new_output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=new_output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - **kwargs, - ) - new_output_sample = new_step_output.prev_sample - new_output_state = new_step_output.state - - self.assertTrue( - jnp.allclose(output_sample, new_output_sample, atol=1e-5), - "Scheduler outputs are not identical", - ) - # Also assert that states are identical - self.assertEqual(output_state.step_index, new_output_state.step_index) - self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) - self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) - # Comparing model_outputs (history) directly: - if output_state.model_outputs is not None and new_output_state.model_outputs is not None: - for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): - self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") - - - def full_loop(self, scheduler=None, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - if scheduler is None: - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - else: - state = scheduler.create_state() # Ensure state is fresh for the loop - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - for i, t in enumerate(state.timesteps): - residual = model(sample, t) - - # scheduler.step in common test receives state, residual, t, sample - step_output = scheduler.step( - state=state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - sample = step_output.prev_sample - state = step_output.state # Update state for next iteration - - return sample - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() # Create initial state - - sample = self.dummy_sample # Get sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif ( - num_inference_steps is not None - and not hasattr(scheduler, "set_timesteps") - ): - kwargs["num_inference_steps"] = num_inference_steps - - # Copy over dummy past residuals (must be done after set_timesteps) - dummy_past_model_outputs = [ - 0.2 * sample, - 0.15 * sample, - 0.10 * sample, - ] - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - - time_step_0 = state.timesteps[5] - time_step_1 = state.timesteps[6] - - output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample - output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - - sample = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_switch(self): - # make sure that iterating over schedulers with same config names gives same results - # for defaults - scheduler_config = self.get_scheduler_config() - scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) - sample_1 = self.full_loop(scheduler=scheduler_1) - result_mean_1 = jnp.mean(jnp.abs(sample_1)) - - assert abs(result_mean_1.item() - 0.2464) < 1e-3 - - scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance - sample_2 = self.full_loop(scheduler=scheduler_2) - result_mean_2 = jnp.mean(jnp.abs(sample_2)) - - self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency - - assert abs(result_mean_2.item() - 0.2464) < 1e-3 - - def test_timesteps(self): - for timesteps in [25, 50, 100, 999, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_thresholding(self): - self.check_over_configs(thresholding=False) - for order in [1, 2, 3]: - for solver_type in ["bh1", "bh2"]: - for threshold in [0.5, 1.0, 2.0]: - for prediction_type in ["epsilon", "sample"]: - with self.assertRaises(NotImplementedError): - self.check_over_configs( - thresholding=True, - prediction_type=prediction_type, - sample_max_value=threshold, - solver_order=order, - solver_type=solver_type, - ) - - def test_prediction_type(self): - for prediction_type in ["epsilon", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) - - def test_rescale_betas_zero_snr(self): - for rescale_zero_terminal_snr in [True, False]: - self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) - - def test_solver_order_and_type(self): - for solver_type in ["bh1", "bh2"]: - for order in [1, 2, 3]: - for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - ) - sample = self.full_loop( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - ) - assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" - - - def test_lower_order_final(self): - self.check_over_configs(lower_order_final=True) - self.check_over_configs(lower_order_final=False) - - def test_inference_steps(self): - for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: - self.check_over_forward(time_step = 0, num_inference_steps=num_inference_steps) - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2464) < 1e-3 - - def test_full_loop_with_karras(self): - # sample = self.full_loop(use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.2925) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(use_karras_sigmas=True) - - def test_full_loop_with_v_prediction(self): - sample = self.full_loop(prediction_type="v_prediction") - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.1014) < 1e-3 - - def test_full_loop_with_karras_and_v_prediction(self): - # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.1966) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - - def test_fp16_support(self): - scheduler_class = self.scheduler_classes[0] - for order in [1, 2, 3]: - for solver_type in ["bh1", "bh2"]: - for prediction_type in ["epsilon", "sample", "v_prediction"]: - scheduler_config = self.get_scheduler_config( - thresholding=False, - dynamic_thresholding_ratio=0, - prediction_type=prediction_type, - solver_order=order, - solver_type=solver_type, - ) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter.astype(jnp.bfloat16) - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - for i, t in enumerate(state.timesteps): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - # sample is casted to fp32 inside step and output should be fp32. - self.assertEqual(sample.dtype, jnp.float32) - - def test_full_loop_with_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - t_start_index = 8 - - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - # add noise - noise = self.dummy_noise_deter - timesteps_for_noise = state.timesteps[t_start_index :] - sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) - - for i, t in enumerate(timesteps_for_noise): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" - assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" - - -class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - width = 8 - - torch_sample = torch.rand((batch_size, num_channels, width)) - jax_sample= jnp.asarray(torch_sample) - return jax_sample - - @property - def dummy_noise_deter(self): - batch_size = 4 - num_channels = 3 - width = 8 - - num_elems = batch_size * num_channels * width - sample = torch.arange(num_elems).flip(-1) - sample = sample.reshape(num_channels, width, batch_size) - sample = sample / num_elems - sample = sample.permute(2, 0, 1) - - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - width = 8 - - num_elems = batch_size * num_channels * width - sample = torch.arange(num_elems) - sample = sample.reshape(num_channels, width, batch_size) - sample = sample / num_elems - sample = sample.permute(2, 0, 1) - jax_sample= jnp.asarray(sample) - return jax_sample - - def test_switch(self): - # make sure that iterating over schedulers with same config names gives same results - # for defaults - scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) - sample = self.full_loop(scheduler=scheduler) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) - scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) - - sample = self.full_loop(scheduler=scheduler) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - def test_full_loop_with_karras(self): - # sample = self.full_loop(use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.2898) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(use_karras_sigmas=True) - - - def test_full_loop_with_v_prediction(self): - sample = self.full_loop(prediction_type="v_prediction") - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.1014) < 1e-3 - - def test_full_loop_with_karras_and_v_prediction(self): - # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.1944) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - - def test_full_loop_with_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - t_start_index = 8 - - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - # add noise - noise = self.dummy_noise_deter - timesteps_for_noise = state.timesteps[t_start_index :] - sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) - - for i, t in enumerate(timesteps_for_noise): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" - assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" - - def test_beta_sigmas(self): - # self.check_over_configs(use_beta_sigmas=True) - with self.assertRaises(NotImplementedError): - self.full_loop(use_beta_sigmas=True) - - def test_exponential_sigmas(self): - #self.check_over_configs(use_exponential_sigmas=True) - with self.assertRaises(NotImplementedError): - self.full_loop(use_exponential_sigmas=True)