diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index dc3aa102be78..3017fc96a5e3 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -28,7 +28,7 @@ jobs: options: --shm-size "16gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -58,7 +58,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: benchmark_test_reports path: benchmarks/${{ env.BASE_PATH }} diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml index b1af44736730..f928e123aa8f 100644 --- a/.github/workflows/build_docker_images.yml +++ b/.github/workflows/build_docker_images.yml @@ -25,10 +25,10 @@ jobs: if: github.event_name == 'pull_request' steps: - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Check out code - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Find Changed Dockerfiles id: file_changes @@ -99,16 +99,16 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ env.REGISTRY }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Build and push - uses: docker/build-push-action@v3 + uses: docker/build-push-action@v6 with: no-cache: true context: ./docker/${{ matrix.image-name }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index f47645c1f659..8e8dc92cb57d 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -17,10 +17,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 000000000000..5ba158b46fde --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,22 @@ +--- +name: CodeQL Security Analysis For Github Actions + +on: + push: + branches: ["main"] + workflow_dispatch: + # pull_request: + +jobs: + codeql: + name: CodeQL Analysis + uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1 + permissions: + security-events: write + packages: read + actions: read + contents: read + with: + languages: '["actions","python"]' + queries: 'security-extended,security-and-quality' + runner: 'ubuntu-latest' #optional if need custom runner diff --git a/.github/workflows/mirror_community_pipeline.yml b/.github/workflows/mirror_community_pipeline.yml index ab4ded973047..73cced7c1394 100644 --- a/.github/workflows/mirror_community_pipeline.yml +++ b/.github/workflows/mirror_community_pipeline.yml @@ -24,7 +24,6 @@ jobs: mirror_community_pipeline: env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }} - runs-on: ubuntu-22.04 steps: # Checkout to correct ref @@ -39,37 +38,41 @@ jobs: # If ref is 'refs/heads/main' => set 'main' # Else it must be a tag => set {tag} - name: Set checkout_ref and path_in_repo + env: + EVENT_NAME: ${{ github.event_name }} + EVENT_INPUT_REF: ${{ 
github.event.inputs.ref }} + GITHUB_REF: ${{ github.ref }} run: | - if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then - if [ -z "${{ github.event.inputs.ref }}" ]; then + if [ "$EVENT_NAME" == "workflow_dispatch" ]; then + if [ -z "$EVENT_INPUT_REF" ]; then echo "Error: Missing ref input" exit 1 - elif [ "${{ github.event.inputs.ref }}" == "main" ]; then + elif [ "$EVENT_INPUT_REF" == "main" ]; then echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV echo "PATH_IN_REPO=main" >> $GITHUB_ENV else - echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV - echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV + echo "CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF" >> $GITHUB_ENV + echo "PATH_IN_REPO=$EVENT_INPUT_REF" >> $GITHUB_ENV fi - elif [ "${{ github.ref }}" == "refs/heads/main" ]; then - echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV + elif [ "$GITHUB_REF" == "refs/heads/main" ]; then + echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV echo "PATH_IN_REPO=main" >> $GITHUB_ENV else # e.g. refs/tags/v0.28.1 -> v0.28.1 - echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV - echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV + echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV + echo "PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\/tags\///')" >> $GITHUB_ENV fi - name: Print env vars run: | echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}" echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}" - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 with: ref: ${{ env.CHECKOUT_REF }} # Setup + install dependencies - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.10" - name: Install dependencies @@ -99,4 +102,4 @@ jobs: - name: Report failure status if: ${{ failure() }} run: | - pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure \ No newline at end of file + pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 8b7e57e91297..416d2af3fc2e 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -28,7 +28,7 @@ jobs: pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }} steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: Install dependencies @@ -44,7 +44,7 @@ jobs: - name: Pipeline Tests Artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: test-pipelines.json path: reports @@ -64,7 +64,7 @@ jobs: options: --shm-size "16gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -97,7 +97,7 @@ jobs: cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pipeline_${{ matrix.module }}_test_reports path: reports @@ -119,7 +119,7 @@ jobs: module: [models, schedulers, lora, others, single_file, examples] steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -167,7 +167,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: 
name: torch_${{ matrix.module }}_cuda_test_reports path: reports @@ -184,7 +184,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -211,7 +211,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_compile_test_reports path: reports @@ -228,7 +228,7 @@ jobs: options: --shm-size "16gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -263,7 +263,7 @@ jobs: cat reports/tests_big_gpu_torch_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_cuda_big_gpu_test_reports path: reports @@ -280,7 +280,7 @@ jobs: shell: bash steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -321,7 +321,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_minimum_version_cuda_test_reports path: reports @@ -355,7 +355,7 @@ jobs: options: --shm-size "20gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -391,7 +391,7 @@ jobs: cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_cuda_${{ matrix.config.backend }}_reports path: reports @@ -408,7 +408,7 @@ jobs: options: --shm-size "20gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -441,7 +441,7 @@ jobs: cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_cuda_pipeline_level_quant_reports path: reports @@ -466,7 +466,7 @@ jobs: image: diffusers/diffusers-pytorch-cpu steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -474,7 +474,7 @@ jobs: run: mkdir -p combined_reports - name: Download all test reports - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: artifacts @@ -500,7 +500,7 @@ jobs: cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY - name: Upload consolidated report - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: consolidated_test_report path: ${{ env.CONSOLIDATED_REPORT_PATH }} @@ -514,7 +514,7 @@ jobs: # # steps: # - name: Checkout diffusers -# uses: actions/checkout@v3 +# uses: actions/checkout@v6 # with: # fetch-depth: 2 # @@ -554,7 +554,7 @@ jobs: # # - name: Test suite reports artifacts # if: ${{ always() }} -# uses: actions/upload-artifact@v4 +# uses: actions/upload-artifact@v6 # with: # name: torch_mps_test_reports # path: reports @@ -570,7 +570,7 @@ jobs: # # steps: # - name: Checkout diffusers -# uses: actions/checkout@v3 +# uses: actions/checkout@v6 # with: # fetch-depth: 2 # @@ -610,7 +610,7 @@ jobs: # # - name: Test suite reports artifacts # if: ${{ always() }} -# uses: actions/upload-artifact@v4 +# uses: 
actions/upload-artifact@v6 # with: # name: torch_mps_test_reports # path: reports diff --git a/.github/workflows/notify_slack_about_release.yml b/.github/workflows/notify_slack_about_release.yml index 612ad4e24503..6c0b96954e81 100644 --- a/.github/workflows/notify_slack_about_release.yml +++ b/.github/workflows/notify_slack_about_release.yml @@ -10,10 +10,10 @@ jobs: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: '3.8' diff --git a/.github/workflows/pr_dependency_test.yml b/.github/workflows/pr_dependency_test.yml index b914d1076190..dfc35c41066f 100644 --- a/.github/workflows/pr_dependency_test.yml +++ b/.github/workflows/pr_dependency_test.yml @@ -18,9 +18,9 @@ jobs: check_dependencies: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.8" - name: Install dependencies diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml index 13c228621f5c..89b502d364ec 100644 --- a/.github/workflows/pr_modular_tests.yml +++ b/.github/workflows/pr_modular_tests.yml @@ -1,3 +1,4 @@ + name: Fast PR tests for Modular on: @@ -35,9 +36,9 @@ jobs: check_code_quality: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.10" - name: Install dependencies @@ -55,9 +56,9 @@ jobs: needs: check_code_quality runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.10" - name: Install dependencies @@ -74,26 +75,34 @@ jobs: if: ${{ failure() }} run: | echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY + check_auto_docs: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v6 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.10" + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[quality] + - name: Check auto docs + run: make modular-autodoctrings + - name: Check if failure + if: ${{ failure() }} + run: | + echo "Auto docstring checks failed. Please run `python utils/modular_auto_docstring.py --fix_and_overwrite`." 
>> $GITHUB_STEP_SUMMARY run_fast_tests: - needs: [check_code_quality, check_repository_consistency] - strategy: - fail-fast: false - matrix: - config: - - name: Fast PyTorch Modular Pipeline CPU tests - framework: pytorch_pipelines - runner: aws-highmemory-32-plus - image: diffusers/diffusers-pytorch-cpu - report: torch_cpu_modular_pipelines - - name: ${{ matrix.config.name }} + needs: [check_code_quality, check_repository_consistency, check_auto_docs] + name: Fast PyTorch Modular Pipeline CPU tests runs-on: - group: ${{ matrix.config.runner }} + group: aws-highmemory-32-plus container: - image: ${{ matrix.config.image }} + image: diffusers/diffusers-pytorch-cpu options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ defaults: @@ -102,7 +111,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -118,22 +127,19 @@ jobs: python utils/print_env.py - name: Run fast PyTorch Pipeline CPU tests - if: ${{ matrix.config.framework == 'pytorch_pipelines' }} run: | pytest -n 8 --max-worker-restart=0 --dist=loadfile \ -k "not Flax and not Onnx" \ - --make-reports=tests_${{ matrix.config.report }} \ + --make-reports=tests_torch_cpu_modular_pipelines \ tests/modular_pipelines - name: Failure short reports if: ${{ failure() }} - run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt + run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: - name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports + name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports path: reports - - diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml index 83b2ab4edbf6..a02a40709fcc 100644 --- a/.github/workflows/pr_test_fetcher.yml +++ b/.github/workflows/pr_test_fetcher.yml @@ -28,7 +28,7 @@ jobs: test_map: ${{ steps.set_matrix.outputs.test_map }} steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Install dependencies @@ -42,7 +42,7 @@ jobs: run: | python utils/tests_fetcher.py | tee test_preparation.txt - name: Report fetched tests - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v6 with: name: test_fetched path: test_preparation.txt @@ -83,7 +83,7 @@ jobs: shell: bash steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -109,7 +109,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v6 with: name: ${{ matrix.modules }}_test_reports path: reports @@ -138,7 +138,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -164,7 +164,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pr_${{ matrix.config.report }}_test_reports path: reports diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 674e62ff443a..c0dfa89e776d 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -31,9 +31,9 @@ jobs: check_code_quality: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: 
actions/setup-python@v6 with: python-version: "3.8" - name: Install dependencies @@ -51,9 +51,9 @@ jobs: needs: check_code_quality runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.8" - name: Install dependencies @@ -108,7 +108,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -153,7 +153,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports path: reports @@ -185,7 +185,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -211,7 +211,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pr_${{ matrix.config.report }}_test_reports path: reports @@ -236,7 +236,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -273,7 +273,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pr_main_test_reports path: reports diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index 468979d379c1..dd20bbe93250 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -32,9 +32,9 @@ jobs: check_code_quality: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.8" - name: Install dependencies @@ -52,9 +52,9 @@ jobs: needs: check_code_quality runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.8" - name: Install dependencies @@ -83,7 +83,7 @@ jobs: pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }} steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: Install dependencies @@ -100,7 +100,7 @@ jobs: echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT - name: Pipeline Tests Artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: test-pipelines.json path: reports @@ -120,7 +120,7 @@ jobs: options: --shm-size "16gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -170,7 +170,7 @@ jobs: cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pipeline_${{ matrix.module }}_test_reports path: reports @@ -193,7 +193,7 @@ jobs: module: [models, schedulers, lora, others] steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -239,7 +239,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: 
torch_cuda_test_reports_${{ matrix.module }} path: reports @@ -255,7 +255,7 @@ jobs: options: --gpus all --shm-size "16gb" --ipc host steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -287,7 +287,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: examples_test_reports path: reports diff --git a/.github/workflows/pr_torch_dependency_test.yml b/.github/workflows/pr_torch_dependency_test.yml index 4b6160ff71e2..4a7e5ab37e47 100644 --- a/.github/workflows/pr_torch_dependency_test.yml +++ b/.github/workflows/pr_torch_dependency_test.yml @@ -18,9 +18,9 @@ jobs: check_torch_dependencies: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.8" - name: Install dependencies diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 7b1c441d3dc0..4456f18c95bc 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -29,7 +29,7 @@ jobs: pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }} steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: Install dependencies @@ -46,7 +46,7 @@ jobs: echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT - name: Pipeline Tests Artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: test-pipelines.json path: reports @@ -66,7 +66,7 @@ jobs: options: --shm-size "16gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -98,7 +98,7 @@ jobs: cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pipeline_${{ matrix.module }}_test_reports path: reports @@ -120,7 +120,7 @@ jobs: module: [models, schedulers, lora, others, single_file] steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -155,7 +155,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_cuda_test_reports_${{ matrix.module }} path: reports @@ -172,7 +172,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -199,7 +199,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_compile_test_reports path: reports @@ -216,7 +216,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -240,7 +240,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_xformers_test_reports path: reports @@ -256,7 +256,7 @@ jobs: options: --gpus all --shm-size "16gb" --ipc host steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -286,7 +286,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: 
actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: examples_test_reports path: reports diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 38cbffaa6315..fe6f6a265e89 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -54,7 +54,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -88,7 +88,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pr_${{ matrix.config.report }}_test_reports path: reports diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 2d6feb592815..cc16d5f82cd0 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -65,7 +65,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pr_torch_mps_test_reports path: reports diff --git a/.github/workflows/pypi_publish.yaml b/.github/workflows/pypi_publish.yaml index dc36b6b024c5..214b996f5381 100644 --- a/.github/workflows/pypi_publish.yaml +++ b/.github/workflows/pypi_publish.yaml @@ -15,10 +15,10 @@ jobs: latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }} steps: - name: Checkout Repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: '3.8' @@ -40,12 +40,12 @@ jobs: steps: - name: Checkout Repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }} - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.8" diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml index efdd6ea2b651..f667d715090d 100644 --- a/.github/workflows/release_tests_fast.yml +++ b/.github/workflows/release_tests_fast.yml @@ -27,7 +27,7 @@ jobs: pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }} steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: Install dependencies @@ -44,7 +44,7 @@ jobs: echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT - name: Pipeline Tests Artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: test-pipelines.json path: reports @@ -64,7 +64,7 @@ jobs: options: --shm-size "16gb" --ipc host --gpus all steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 - name: NVIDIA-SMI @@ -94,7 +94,7 @@ jobs: cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: pipeline_${{ matrix.module }}_test_reports path: reports @@ -116,7 +116,7 @@ jobs: module: [models, schedulers, lora, others, single_file] steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -149,7 +149,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: 
actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_cuda_${{ matrix.module }}_test_reports path: reports @@ -166,7 +166,7 @@ jobs: shell: bash steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -205,7 +205,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_minimum_version_cuda_test_reports path: reports @@ -222,7 +222,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -247,7 +247,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_compile_test_reports path: reports @@ -264,7 +264,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -288,7 +288,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: torch_xformers_test_reports path: reports @@ -305,7 +305,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 @@ -336,7 +336,7 @@ jobs: - name: Test suite reports artifacts if: ${{ always() }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: examples_test_reports path: reports diff --git a/.github/workflows/run_tests_from_a_pr.yml b/.github/workflows/run_tests_from_a_pr.yml index fa8c579dd768..3e5462f5100f 100644 --- a/.github/workflows/run_tests_from_a_pr.yml +++ b/.github/workflows/run_tests_from_a_pr.yml @@ -57,7 +57,7 @@ jobs: shell: bash -e {0} - name: Checkout PR branch - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: refs/pull/${{ inputs.pr_number }}/head diff --git a/.github/workflows/ssh-pr-runner.yml b/.github/workflows/ssh-pr-runner.yml index 49fa9c0ad24d..27246fb61348 100644 --- a/.github/workflows/ssh-pr-runner.yml +++ b/.github/workflows/ssh-pr-runner.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml index 917eb5b1b31a..4fbfad3dc7c6 100644 --- a/.github/workflows/ssh-runner.yml +++ b/.github/workflows/ssh-runner.yml @@ -35,7 +35,7 @@ jobs: steps: - name: Checkout diffusers - uses: actions/checkout@v3 + uses: actions/checkout@v6 with: fetch-depth: 2 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 27450ed4c7f2..b0a90e278550 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -15,10 +15,10 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v1 + uses: actions/setup-python@v6 with: python-version: 3.8 diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index 4743dc352455..65334e086c83 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 - name: Secret Scanning diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index 6d2f2fc8dd9a..87ea38a5bbac 100644 --- 
a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: typos-action - uses: crate-ci/typos@v1.12.4 + uses: crate-ci/typos@v1.42.1 diff --git a/.github/workflows/update_metadata.yml b/.github/workflows/update_metadata.yml index 92aea0369ba8..6e608883c13a 100644 --- a/.github/workflows/update_metadata.yml +++ b/.github/workflows/update_metadata.yml @@ -15,7 +15,7 @@ jobs: shell: bash -l {0} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v6 - name: Setup environment run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index ec18df882641..000000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,506 +0,0 @@ - - -# How to contribute to Diffusers 🧨 - -We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it! - -Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. Join us on Discord - -Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility. - -We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered. - -## Overview - -You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to -the core library. - -In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community. - -* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR). -* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose). -* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues). -* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). -* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source). -* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples). -* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples). -* 8. 
Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22). -* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md). - -As said before, **all contributions are valuable to the community**. -In the following, we will explain each contribution a bit more in detail. - -For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr). - -### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord - -Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to): -- Reports of training or inference experiments in an attempt to share knowledge -- Presentation of personal projects -- Questions to non-official training examples -- Project proposals -- General feedback -- Paper summaries -- Asking for help on personal projects that build on top of the Diffusers library -- General questions -- Ethical questions regarding diffusion models -- ... - -Every question that is asked on the forum or on Discord actively encourages the community to publicly -share knowledge and might very well help a beginner in the future who has the same question you're -having. Please do pose any questions you might have. -In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from. - -**Please** keep in mind that the more effort you put into asking or answering a question, the higher -the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database. -In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. - -**NOTE about channels**: -[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago. -In addition, questions and answers posted in the forum can easily be linked to. -In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication. -While it will most likely take less time for you to get an answer to your question on Discord, your -question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. 
We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers. - -### 2. Opening new issues on the GitHub issues tab - -The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of -the problems they encounter. So thank you for reporting an issue. - -Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design. - -In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR). - -**Please consider the following guidelines when opening a new issue**: -- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues). -- Please never report a new issue on another (related) issue. If another issue is highly related, please -open a new issue nevertheless and link to the related issue. -- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English. -- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c "import diffusers; print(diffusers.__version__)"` is higher or matches the latest Diffusers version. -- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues. - -New issues usually include the following. - -#### 2.1. Reproducible, minimal bug reports - -A bug report should always have a reproducible code snippet and be as minimal and concise as possible. -This means in more detail: -- Narrow the bug down as much as you can, **do not just dump your whole code file**. -- Format your code. -- Do not include any external libraries except for Diffusers depending on them. -- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue. -- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it. -- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell. -- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible. 
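For instance, a minimal report often boils down to the output of `diffusers-cli env` followed by a snippet along these lines (purely a sketch; the checkpoint, prompt, and settings below are illustrative placeholders, not taken from a real issue):

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; replace with the model that actually triggers your issue.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.to("cuda")

# The exact call that fails, with fixed inputs so the reader can reproduce it as-is.
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("output.png")
# Paste the full error message / traceback produced by the call above.
```

Anything beyond what is needed to trigger the error should be stripped out before posting.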
- -For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section. - -You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml). - -#### 2.2. Feature requests - -A world-class feature request addresses the following points: - -1. Motivation first: -* Is it related to a problem/frustration with the library? If so, please explain -why. Providing a code snippet that demonstrates the problem is best. -* Is it related to something you would need for a project? We'd love to hear -about it! -* Is it something you worked on and think could benefit the community? -Awesome! Tell us what problem it solved for you. -2. Write a *full paragraph* describing the feature; -3. Provide a **code snippet** that demonstrates its future use; -4. In case this is related to a paper, please attach a link; -5. Attach any additional information (drawings, screenshots, etc.) you think may help. - -You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=). - -#### 2.3 Feedback - -Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed. -If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions. - -You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=). - -#### 2.4 Technical questions - -Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on -why this part of the code is difficult to understand. - -You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml). - -#### 2.5 Proposal to add a new model, scheduler, or pipeline - -If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information: - -* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release. -* Link to any of its open-source implementation. -* Link to the model weights if they are available. - -If you are willing to contribute to the model yourself, let us know so we can best guide you. Also, don't forget -to tag the original author of the component (model, scheduler, pipeline, etc.) by GitHub handle if you can find it. - -You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml). - -### 3. 
Answering issues on the GitHub issues tab - -Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct. -Some tips to give a high-quality answer to an issue: -- Be as concise and minimal as possible. -- Stay on topic. An answer to the issue should concern the issue and only the issue. -- Provide links to code, papers, or other sources that prove or encourage your point. -- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet. - -Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great -help to the maintainers if you can answer such issues, encouraging the author of the issue to be -more precise, provide the link to a duplicated issue or redirect them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR). - -If you have verified that the issued bug report is correct and requires a correction in the source code, -please have a look at the next sections. - -For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section. - -### 4. Fixing a "Good first issue" - -*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already -explains how a potential solution should look so that it is easier to fix. -If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message "I would like to try this issue.". There are usually three scenarios: -- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it. -- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR. -- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR. - - -### 5. Contribute to the documentation - -A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly -valuable contribution**. - -Contributing to the library can have many forms: - -- Correcting spelling or grammatical errors. -- Correct incorrect formatting of the docstring. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it. 
-- Correct the shape or dimensions of a docstring input or output tensor. -- Clarify documentation that is hard to understand or incorrect. -- Update outdated code examples. -- Translating the documentation to another language. - -Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected, adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source). - -Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally. - - -### 6. Contribute a community pipeline - -[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user. -Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview). -We support two types of pipelines: - -- Official Pipelines -- Community Pipelines - -Both official and community pipelines follow the same design and consist of the same type of components. - -Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code -resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines). -In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested. -They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution. - -The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all -possible ways diffusion models can be used for inference, but some of them may be of interest to the community. -Officially released diffusion pipelines, -such as Stable Diffusion are added to the core src/diffusers/pipelines package which ensures -high quality of maintenance, no backward-breaking code changes, and testing. -More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library. - -To add a community pipeline, one should add a .py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline. - -An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400). - -Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors. - -Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the -core package. - -### 7. Contribute to training examples - -Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples). 
- -We support two types of training examples: - -- Official training examples -- Research training examples - -Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders. -The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community. -This is for the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models. -If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author. - -Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the -training examples, it is required to clone the repository: - -```bash -git clone https://github.com/huggingface/diffusers -``` - -as well as to install all additional dependencies required for training: - -```bash -cd diffusers -pip install -r examples/<your-example-folder>/requirements.txt -``` - -Therefore, when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt). - -Training examples of the Diffusers library should adhere to the following philosophy: -- All the code necessary to run the examples should be found in a single Python file. -- One should be able to run the example from the command line with `python <your-example>.py --args`. -- Examples should be kept simple and serve as **an example** of how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials. - -To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of what they should look like. -We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated -with Diffusers. -Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include: -- An example command on how to run the example script as shown [here e.g.](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch). -- A link to some training results (logs, models, ...) that show what the user can expect as shown [here e.g.](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5). 
-- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations). - -If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples. - -### 8. Fixing a "Good second issue" - -*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are -usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). -The issue description usually gives less guidance on how to fix the issue and requires -a decent understanding of the library by the interested contributor. -If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR. -Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged. - -### 9. Adding pipelines, models, schedulers - -Pipelines, models, and schedulers are the most important pieces of the Diffusers library. -They provide easy access to state-of-the-art diffusion technologies and thus allow the community to -build powerful generative AI applications. - -By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem. - -Diffusers has a couple of open feature requests for all three components - feel free to gloss over them -if you don't know yet what specific component you would like to add: -- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) -- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) - -Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that -we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy -as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please -open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design -pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us. 
- -Please make sure to add links to the original codebase/paper to the PR and ideally also ping the -original author directly on the PR so that they can follow the progress and potentially help with questions. - -If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help. - -## How to write a good issue - -**The better your issue is written, the higher the chances that it will be quickly resolved.** - -1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose). -2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers". -3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data. -4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets. -5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better. -6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information. -7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. 
Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library. - -## How to write a good PR - -1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged. -2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once. -3. If helpful, try to add a code snippet that displays an example of how your addition can be used. -4. The title of your pull request should be a summary of its contribution. -5. If your pull request addresses an issue, please mention the issue number in -the pull request description to make sure they are linked (and people -consulting the issue know you are working on it); -6. To indicate a work in progress please prefix the title with `[WIP]`. These -are useful to avoid duplicated work, and to differentiate it from PRs ready -to be merged; -7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue). -8. Make sure existing tests pass; -9. Add high-coverage tests. No quality testing = no merge. -- If you are adding new `@slow` tests, make sure they pass using -`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`. -CircleCI does not run the slow tests, but GitHub Actions does every night! -10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example. -11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like -[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files. -If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images -to this dataset. - -## How to open a PR - -Before writing code, we strongly advise you to search through the existing PRs or -issues to make sure that nobody is already working on the same thing. If you are -unsure, it is always a good idea to open an issue to get some feedback. - -You will need basic `git` proficiency to be able to contribute to -🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest -manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro -Git](https://git-scm.com/book/en/v2) is a very good reference. - -Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)): - -1. Fork the [repository](https://github.com/huggingface/diffusers) by -clicking on the 'Fork' button on the repository's page. 
This creates a copy of the code -under your GitHub user account. - -2. Clone your fork to your local disk, and add the base repository as a remote: - - ```bash - $ git clone git@github.com:/diffusers.git - $ cd diffusers - $ git remote add upstream https://github.com/huggingface/diffusers.git - ``` - -3. Create a new branch to hold your development changes: - - ```bash - $ git checkout -b a-descriptive-name-for-my-changes - ``` - -**Do not** work on the `main` branch. - -4. Set up a development environment by running the following command in a virtual environment: - - ```bash - $ pip install -e ".[dev]" - ``` - -If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the -library. - -5. Develop the features on your branch. - -As you work on the features, you should make sure that the test suite -passes. You should run the tests impacted by your changes like this: - - ```bash - $ pytest tests/.py - ``` - -Before you run the tests, please make sure you install the dependencies required for testing. You can do so -with this command: - - ```bash - $ pip install -e ".[test]" - ``` - -You can also run the full test suite with the following command, but it takes -a beefy machine to produce a result in a decent amount of time now that -Diffusers has grown a lot. Here is the command for it: - - ```bash - $ make test - ``` - -🧨 Diffusers relies on `ruff` and `isort` to format its source code -consistently. After you make changes, apply automatic style corrections and code verifications -that can't be automated in one go with: - - ```bash - $ make style - ``` - -🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality -control runs in CI, however, you can also run the same checks with: - - ```bash - $ make quality - ``` - -Once you're happy with your changes, add changed files using `git add` and -make a commit with `git commit` to record your changes locally: - - ```bash - $ git add modified_file.py - $ git commit -m "A descriptive message about your changes." - ``` - -It is a good idea to sync your copy of the code with the original -repository regularly. This way you can quickly account for changes: - - ```bash - $ git pull upstream main - ``` - -Push the changes to your account using: - - ```bash - $ git push -u origin a-descriptive-name-for-my-changes - ``` - -6. Once you are satisfied, go to the -webpage of your fork on GitHub. Click on 'Pull request' to send your changes -to the project maintainers for review. - -7. It's ok if maintainers ask you for changes. It happens to core contributors -too! So everyone can see the changes in the Pull request, work in your local -branch and push the changes to your fork. They will automatically appear in -the pull request. - -### Tests - -An extensive test suite is included to test the library behavior and several examples. Library tests can be found in -the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests). - -We like `pytest` and `pytest-xdist` because it's faster. From the root of the -repository, here's how to run tests with `pytest` for the library: - -```bash -$ python -m pytest -n auto --dist=loadfile -s -v ./tests/ -``` - -In fact, that's how `make test` is implemented! - -You can specify a smaller set of tests in order to test only the feature -you're working on. - -By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to -`yes` to run them. 
This will download many gigabytes of models — make sure you -have enough disk space and a good Internet connection, or a lot of patience! - -```bash -$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/ -``` - -`unittest` is fully supported, here's how to run tests with it: - -```bash -$ python -m unittest discover -s tests -t . -v -$ python -m unittest discover -s examples -t examples -v -``` - -### Syncing forked main with upstream (HuggingFace) main - -To avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs, -when syncing the main branch of a forked repository, please, follow these steps: -1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main. -2. If a PR is absolutely necessary, use the following steps after checking out your branch: -```bash -$ git checkout -b your-branch-for-syncing -$ git pull --squash --no-commit upstream main -$ git commit -m '' -$ git push --set-upstream origin your-branch-for-syncing -``` - -### Style guide - -For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 120000 index 000000000000..53de38ca21e3 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1 @@ +docs/source/en/conceptual/contribution.md \ No newline at end of file diff --git a/Makefile b/Makefile index 9af2e8b1a5c9..b90ff82ab268 100644 --- a/Makefile +++ b/Makefile @@ -70,6 +70,10 @@ fix-copies: python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite +# Auto docstrings in modular blocks +modular-autodoctrings: + python utils/modular_auto_docstring.py + # Run tests for the library test: diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index a700d1db72bc..b11d8a491168 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -1,8 +1,8 @@ -FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04 +FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 LABEL maintainer="Hugging Face" LABEL repository="diffusers" -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.10 ENV DEBIAN_FRONTEND=noninteractive RUN apt-get -y update \ @@ -32,10 +32,17 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV} ENV PATH="$VIRTUAL_ENV/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# Install torch, torchvision, and torchaudio together to ensure compatibility RUN uv pip install --no-cache-dir \ torch \ torchvision \ - torchaudio + torchaudio \ + --index-url https://download.pytorch.org/whl/cu129 + +# Install compatible versions of numba/llvmlite for Python 3.10+ +RUN uv pip install --no-cache-dir \ + "llvmlite>=0.40.0" \ + "numba>=0.57.0" RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]" diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile index eae7eaf4faf1..33cae319b93c 100644 --- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile +++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile @@ -1,8 +1,8 @@ -FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04 +FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04 LABEL maintainer="Hugging Face" LABEL repository="diffusers" -ARG PYTHON_VERSION=3.12 +ARG PYTHON_VERSION=3.10 
ENV DEBIAN_FRONTEND=noninteractive RUN apt-get -y update \ @@ -32,10 +32,17 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV} ENV PATH="$VIRTUAL_ENV/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) +# Install torch, torchvision, and torchaudio together to ensure compatibility RUN uv pip install --no-cache-dir \ torch \ torchvision \ - torchaudio + torchaudio \ + --index-url https://download.pytorch.org/whl/cu129 + +# Install compatible versions of numba/llvmlite for Python 3.10+ +RUN uv pip install --no-cache-dir \ + "llvmlite>=0.40.0" \ + "numba>=0.57.0" RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]" diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 748389f373aa..64a4222845b0 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -54,6 +54,8 @@ title: Batch inference - local: training/distributed_inference title: Distributed inference + - local: hybrid_inference/overview + title: Remote inference title: Inference - isExpanded: false sections: @@ -88,17 +90,6 @@ title: FreeU title: Community optimizations title: Inference optimization -- isExpanded: false - sections: - - local: hybrid_inference/overview - title: Overview - - local: hybrid_inference/vae_decode - title: VAE Decode - - local: hybrid_inference/vae_encode - title: VAE Encode - - local: hybrid_inference/api_reference - title: API Reference - title: Hybrid Inference - isExpanded: false sections: - local: modular_diffusers/overview @@ -123,6 +114,8 @@ title: Guiders - local: modular_diffusers/custom_blocks title: Building Custom Blocks + - local: modular_diffusers/mellon + title: Using Custom Blocks with Mellon title: Modular Diffusers - isExpanded: false sections: @@ -270,6 +263,8 @@ title: Outputs - local: api/quantization title: Quantization + - local: hybrid_inference/api_reference + title: Remote inference - local: api/parallel title: Parallel inference title: Main Classes @@ -353,6 +348,8 @@ title: Flux2Transformer2DModel - local: api/models/flux_transformer title: FluxTransformer2DModel + - local: api/models/glm_image_transformer2d + title: GlmImageTransformer2DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d @@ -365,6 +362,10 @@ title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel + - local: api/models/longcat_image_transformer2d + title: LongCatImageTransformer2DModel + - local: api/models/ltx2_video_transformer3d + title: LTX2VideoTransformer3DModel - local: api/models/ltx_video_transformer3d title: LTXVideoTransformer3DModel - local: api/models/lumina2_transformer2d @@ -401,6 +402,8 @@ title: WanAnimateTransformer3DModel - local: api/models/wan_transformer_3d title: WanTransformer3DModel + - local: api/models/z_image_transformer2d + title: ZImageTransformer2DModel title: Transformers - sections: - local: api/models/stable_cascade_unet @@ -439,6 +442,10 @@ title: AutoencoderKLHunyuanVideo - local: api/models/autoencoder_kl_hunyuan_video15 title: AutoencoderKLHunyuanVideo15 + - local: api/models/autoencoderkl_audio_ltx_2 + title: AutoencoderKLLTX2Audio + - local: api/models/autoencoderkl_ltx_2 + title: AutoencoderKLLTX2Video - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_magvit @@ -491,6 +498,8 @@ title: Bria 3.2 - local: api/pipelines/bria_fibo 
title: Bria Fibo + - local: api/pipelines/bria_fibo_edit + title: Bria Fibo Edit - local: api/pipelines/chroma title: Chroma - local: api/pipelines/cogview3 @@ -537,6 +546,8 @@ title: Flux2 - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint + - local: api/pipelines/glm_image + title: GLM-Image - local: api/pipelines/hidream title: HiDream-I1 - local: api/pipelines/hunyuandit @@ -551,6 +562,8 @@ title: Kandinsky 2.2 - local: api/pipelines/kandinsky3 title: Kandinsky 3 + - local: api/pipelines/kandinsky5_image + title: Kandinsky 5.0 Image - local: api/pipelines/kolors title: Kolors - local: api/pipelines/latent_consistency_models @@ -559,6 +572,8 @@ title: Latent Diffusion - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/longcat_image + title: LongCat-Image - local: api/pipelines/lumina2 title: Lumina 2.0 - local: api/pipelines/lumina @@ -646,6 +661,8 @@ title: VisualCloze - local: api/pipelines/wuerstchen title: Wuerstchen + - local: api/pipelines/z_image + title: Z-Image title: Image - sections: - local: api/pipelines/allegro @@ -664,12 +681,12 @@ title: HunyuanVideo1.5 - local: api/pipelines/i2vgenxl title: I2VGen-XL - - local: api/pipelines/kandinsky5_image - title: Kandinsky 5.0 Image - local: api/pipelines/kandinsky5_video title: Kandinsky 5.0 Video - local: api/pipelines/latte title: Latte + - local: api/pipelines/ltx2 + title: LTX-2 - local: api/pipelines/ltx_video title: LTXVideo - local: api/pipelines/mochi diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index 9ba474208551..6a2d74892cfa 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -29,8 +29,14 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate [[autodoc]] apply_faster_cache -### FirstBlockCacheConfig +## FirstBlockCacheConfig [[autodoc]] FirstBlockCacheConfig [[autodoc]] apply_first_block_cache + +### TaylorSeerCacheConfig + +[[autodoc]] TaylorSeerCacheConfig + +[[autodoc]] apply_taylorseer_cache diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 7911bc2b2332..bbae6a9020af 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -33,6 +33,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen). - [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage). - [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2). +- [`LTX2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2). - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. 
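All of these mixins are used the same way from a pipeline instance. The snippet below is a minimal sketch of the load/weight/fuse flow they provide — the LoRA repository id and adapter name are placeholders, not real checkpoints:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights / set_adapters / fuse_lora come from the LoRA loader mixins
pipe.load_lora_weights("some-user/some-sdxl-lora", adapter_name="style")  # placeholder repo id
pipe.set_adapters(["style"], adapter_weights=[0.8])
pipe.fuse_lora(adapter_names=["style"], lora_scale=0.8)

image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("lora_sketch.png")
```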
> [!TIP] @@ -62,6 +63,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin +## LTX2LoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin + ## CogVideoXLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin diff --git a/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md new file mode 100644 index 000000000000..d0024474e9e0 --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_audio_ltx_2.md @@ -0,0 +1,29 @@ + + +# AutoencoderKLLTX2Audio + +The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLLTX2Audio + +vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLLTX2Audio + +[[autodoc]] AutoencoderKLLTX2Audio + - encode + - decode + - all \ No newline at end of file diff --git a/docs/source/en/api/models/autoencoderkl_ltx_2.md b/docs/source/en/api/models/autoencoderkl_ltx_2.md new file mode 100644 index 000000000000..1dbf516c017a --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_ltx_2.md @@ -0,0 +1,29 @@ + + +# AutoencoderKLLTX2Video + +The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLLTX2Video + +vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLLTX2Video + +[[autodoc]] AutoencoderKLLTX2Video + - decode + - encode + - all diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md index f56b7383a0d7..0821d63fd152 100644 --- a/docs/source/en/api/models/controlnet.md +++ b/docs/source/en/api/models/controlnet.md @@ -33,6 +33,21 @@ url = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/m pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) ``` +## Loading from Control LoRA + +Control-LoRA is introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) by adding low-rank parameter efficient fine tuning to ControlNet. This approach offers a more efficient and compact method to bring model control to a wider variety of consumer GPUs. 
+ +```py +from diffusers import ControlNetModel, UNet2DConditionModel + +lora_id = "stabilityai/control-lora" +lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" + +unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16).to("cuda") +controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16) +controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config) +``` + ## ControlNetModel [[autodoc]] ControlNetModel diff --git a/docs/source/en/api/models/controlnet_flux.md b/docs/source/en/api/models/controlnet_flux.md index 6b230d90fba3..ec0370c19e06 100644 --- a/docs/source/en/api/models/controlnet_flux.md +++ b/docs/source/en/api/models/controlnet_flux.md @@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co ## FluxControlNetOutput -[[autodoc]] models.controlnet_flux.FluxControlNetOutput \ No newline at end of file +[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput \ No newline at end of file diff --git a/docs/source/en/api/models/controlnet_sparsectrl.md b/docs/source/en/api/models/controlnet_sparsectrl.md index b9e81dc57eeb..0aa9848d0d2b 100644 --- a/docs/source/en/api/models/controlnet_sparsectrl.md +++ b/docs/source/en/api/models/controlnet_sparsectrl.md @@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr ## SparseControlNetOutput -[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput +[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md new file mode 100644 index 000000000000..7a18d1050075 --- /dev/null +++ b/docs/source/en/api/models/glm_image_transformer2d.md @@ -0,0 +1,18 @@ + + +# GlmImageTransformer2DModel + +A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO). + +## GlmImageTransformer2DModel + +[[autodoc]] GlmImageTransformer2DModel diff --git a/docs/source/en/api/models/longcat_image_transformer2d.md b/docs/source/en/api/models/longcat_image_transformer2d.md new file mode 100644 index 000000000000..f40b2583e68b --- /dev/null +++ b/docs/source/en/api/models/longcat_image_transformer2d.md @@ -0,0 +1,25 @@ + + +# LongCatImageTransformer2DModel + +The model can be loaded with the following code snippet. + +```python +from diffusers import LongCatImageTransformer2DModel + +transformer = LongCatImageTransformer2DModel.from_pretrained("meituan-longcat/LongCat-Image ", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## LongCatImageTransformer2DModel + +[[autodoc]] LongCatImageTransformer2DModel \ No newline at end of file diff --git a/docs/source/en/api/models/ltx2_video_transformer3d.md b/docs/source/en/api/models/ltx2_video_transformer3d.md new file mode 100644 index 000000000000..9faab8695468 --- /dev/null +++ b/docs/source/en/api/models/ltx2_video_transformer3d.md @@ -0,0 +1,26 @@ + + +# LTX2VideoTransformer3DModel + +A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. + +The model can be loaded with the following code snippet. 
+ +```python +from diffusers import LTX2VideoTransformer3DModel + +transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## LTX2VideoTransformer3DModel + +[[autodoc]] LTX2VideoTransformer3DModel diff --git a/docs/source/en/api/models/z_image_transformer2d.md b/docs/source/en/api/models/z_image_transformer2d.md new file mode 100644 index 000000000000..2ecb9851febd --- /dev/null +++ b/docs/source/en/api/models/z_image_transformer2d.md @@ -0,0 +1,19 @@ + + +# ZImageTransformer2DModel + +A Transformer model for image-like data from [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo). + +## ZImageTransformer2DModel + +[[autodoc]] ZImageTransformer2DModel \ No newline at end of file diff --git a/docs/source/en/api/pipelines/bria_fibo_edit.md b/docs/source/en/api/pipelines/bria_fibo_edit.md new file mode 100644 index 000000000000..b46dd78cdb90 --- /dev/null +++ b/docs/source/en/api/pipelines/bria_fibo_edit.md @@ -0,0 +1,33 @@ + + +# Bria Fibo Edit + +Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows. +Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments. +Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality + +## Usage +_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._ + +Use the command below to log in: + +```bash +hf auth login +``` + + +## BriaFiboEditPipeline + +[[autodoc]] BriaFiboEditPipeline + - all + - __call__ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md index cc52ffa09a6d..2b3b50c25e80 100644 --- a/docs/source/en/api/pipelines/chroma.md +++ b/docs/source/en/api/pipelines/chroma.md @@ -99,3 +99,9 @@ image.save("chroma-single-file.png") [[autodoc]] ChromaImg2ImgPipeline - all - __call__ + +## ChromaInpaintPipeline + +[[autodoc]] ChromaInpaintPipeline + - all + - __call__ diff --git a/docs/source/en/api/pipelines/chronoedit.md b/docs/source/en/api/pipelines/chronoedit.md index 48e70ab9e55e..5e7057f9ccb8 100644 --- a/docs/source/en/api/pipelines/chronoedit.md +++ b/docs/source/en/api/pipelines/chronoedit.md @@ -30,6 +30,10 @@ The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face. 
+Available Models/LoRAs: +- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers) +- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora) +- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora) ### Image Editing @@ -100,6 +104,7 @@ Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.pn import torch import numpy as np from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline +from diffusers.schedulers import UniPCMultistepScheduler from diffusers.utils import export_to_video, load_image from transformers import CLIPVisionModel from PIL import Image @@ -109,9 +114,8 @@ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encod vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16) pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16) -lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors") -pipe.load_lora_weights(lora_path) -pipe.fuse_lora(lora_scale=1.0) +pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill") +pipe.fuse_lora(adapter_names=["distill"], lora_scale=1.0) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0) pipe.to("cuda") @@ -145,6 +149,57 @@ export_to_video(output, "output.mp4", fps=16) Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png") ``` +### Inference with Multiple LoRAs + +```py +import torch +import numpy as np +from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_image +from transformers import CLIPVisionModel +from PIL import Image + +model_id = "nvidia/ChronoEdit-14B-Diffusers" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16) +pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16) +pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora", weight_name="paintbrush_lora_diffusers.safetensors", adapter_name="paintbrush") +pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill") +pipe.fuse_lora(adapter_names=["paintbrush", "distill"], lora_scale=1.0) +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0) +pipe.to("cuda") + +image = load_image( + "https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png" +) +max_area = 720 * 1280 +aspect_ratio = image.height / image.width +mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] +height = round(np.sqrt(max_area * aspect_ratio)) // 
mod_value * mod_value +width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value +print("width", width, "height", height) +image = image.resize((width, height)) +prompt = ( + "Turn the pencil sketch in the image into an actual object that is consistent with the image’s content. The user wants to change the sketch to a crown and a hat." +) + +output = pipe( + image=image, + prompt=prompt, + height=height, + width=width, + num_frames=5, + num_inference_steps=8, + guidance_scale=1.0, + enable_temporal_reasoning=False, + num_temporal_reasoning_steps=0, +).frames[0] +export_to_video(output, "output.mp4", fps=16) +Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_1.png") +``` + ## ChronoEditPipeline [[autodoc]] ChronoEditPipeline diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md index fb9453480e74..60ecce660303 100644 --- a/docs/source/en/api/pipelines/cosmos.md +++ b/docs/source/en/api/pipelines/cosmos.md @@ -70,6 +70,12 @@ output.save("output.png") - all - __call__ +## Cosmos2_5_PredictBasePipeline + +[[autodoc]] Cosmos2_5_PredictBasePipeline + - all + - __call__ + ## CosmosPipelineOutput [[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput diff --git a/docs/source/en/api/pipelines/diffedit.md b/docs/source/en/api/pipelines/diffedit.md index 9734ca2eabc3..670b7bb4fca0 100644 --- a/docs/source/en/api/pipelines/diffedit.md +++ b/docs/source/en/api/pipelines/diffedit.md @@ -21,7 +21,7 @@ The abstract from the paper is: *Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.* -The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html). +The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html). This pipeline was contributed by [clarencechen](https://github.com/clarencechen). 
❤️ diff --git a/docs/source/en/api/pipelines/flux2.md b/docs/source/en/api/pipelines/flux2.md index 393e0d03c341..4ace2f3b3aa0 100644 --- a/docs/source/en/api/pipelines/flux2.md +++ b/docs/source/en/api/pipelines/flux2.md @@ -35,5 +35,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a ## Flux2Pipeline [[autodoc]] Flux2Pipeline + - all + - __call__ + +## Flux2KleinPipeline + +[[autodoc]] Flux2KleinPipeline - all - __call__ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md new file mode 100644 index 000000000000..a99832787847 --- /dev/null +++ b/docs/source/en/api/pipelines/glm_image.md @@ -0,0 +1,95 @@ + + +# GLM-Image + +## Overview + +GLM-Image is an image generation model adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios. + +Model architecture: a hybrid autoregressive + diffusion decoder design、 + ++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of `transformers` library. ++ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images. + +Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality. + ++ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness. ++ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering. + +GLM-Image supports both text-to-image and image-to-image generation within a single model + ++ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios. ++ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects. + +This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image). + +## Usage examples + +### Text to Image Generation + +```python +import torch +from diffusers.pipelines.glm_image import GlmImagePipeline + +pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") +prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. 
The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy." +image = pipe( + prompt=prompt, + height=32 * 32, + width=36 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image.save("output_t2i.png") +``` + +### Image to Image Generation + +```python +import torch +from diffusers.pipelines.glm_image import GlmImagePipeline +from PIL import Image + +pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda") +image_path = "cond.jpg" +prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator." +image = Image.open(image_path).convert("RGB") +image = pipe( + prompt=prompt, + image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1] + height=33 * 32, + width=32 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image.save("output_i2i.png") +``` + ++ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model. + +## GlmImagePipeline + +[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline + - all + - __call__ + +## GlmImagePipelineOutput + +[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput diff --git a/docs/source/en/api/pipelines/kandinsky5_image.md b/docs/source/en/api/pipelines/kandinsky5_image.md index e30a1e3ee529..1125e1594b03 100644 --- a/docs/source/en/api/pipelines/kandinsky5_image.md +++ b/docs/source/en/api/pipelines/kandinsky5_image.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. 
[Kandinsky 5.0](https://arxiv.org/abs/2511.14993) is a family of diffusion models for Video & Image generation. -Kandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters) +Kandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters). The model introduces several key innovations: - **Latent diffusion pipeline** with **Flow Matching** for improved training stability @@ -21,10 +21,14 @@ The model introduces several key innovations: The original codebase can be found at [kandinskylab/Kandinsky-5](https://github.com/kandinskylab/Kandinsky-5). +> [!TIP] +> Check out the [Kandinsky Lab](https://huggingface.co/kandinskylab) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants. + ## Available Models Kandinsky 5.0 Image Lite: + | model_id | Description | Use Cases | |------------|-------------|-----------| | [**kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers) | 6B image Supervised Fine-Tuned model | Highest generation quality | diff --git a/docs/source/en/api/pipelines/kandinsky5_video.md b/docs/source/en/api/pipelines/kandinsky5_video.md index d7bc76c9bfbe..733e2481732a 100644 --- a/docs/source/en/api/pipelines/kandinsky5_video.md +++ b/docs/source/en/api/pipelines/kandinsky5_video.md @@ -30,6 +30,7 @@ The original codebase can be found at [kandinskylab/Kandinsky-5](https://github. ## Available Models Kandinsky 5.0 T2V Pro: + | model_id | Description | Use Cases | |------------|-------------|-----------| | **kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers** | 5 second Text-to-Video Pro model | High-quality text-to-video generation | diff --git a/docs/source/en/api/pipelines/longcat_image.md b/docs/source/en/api/pipelines/longcat_image.md new file mode 100644 index 000000000000..a7e8a7a3712e --- /dev/null +++ b/docs/source/en/api/pipelines/longcat_image.md @@ -0,0 +1,114 @@ + + +# LongCat-Image + +
+ LoRA +
+ + +We introduce LongCat-Image, a pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models. + + +### Key Features +- 🌟 **Exceptional Efficiency and Performance**: With only **6B parameters**, LongCat-Image surpasses numerous open-source models that are several times larger across multiple benchmarks, demonstrating the immense potential of efficient model design. +- 🌟 **Superior Editing Performance**: LongCat-Image-Edit model achieves state-of-the-art performance among open-source models, delivering leading instruction-following and image quality with superior visual consistency. +- 🌟 **Powerful Chinese Text Rendering**: LongCat-Image demonstrates superior accuracy and stability in rendering common Chinese characters compared to existing SOTA open-source models and achieves industry-leading coverage of the Chinese dictionary. +- 🌟 **Remarkable Photorealism**: Through an innovative data strategy and training framework, LongCat-Image achieves remarkable photorealism in generated images. +- 🌟 **Comprehensive Open-Source Ecosystem**: We provide a complete toolchain, from intermediate checkpoints to full training code, significantly lowering the barrier for further research and development. + +For more details, please refer to the comprehensive [***LongCat-Image Technical Report***](https://arxiv.org/abs/2412.11963) + + +## Usage Example + +```py +import torch +import diffusers +from diffusers import LongCatImagePipeline + +weight_dtype = torch.bfloat16 +pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16 ) +pipe.to('cuda') +# pipe.enable_model_cpu_offload() + +prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。' +image = pipe( + prompt, + height=768, + width=1344, + guidance_scale=4.0, + num_inference_steps=50, + num_images_per_prompt=1, + generator=torch.Generator("cpu").manual_seed(43), + enable_cfg_renorm=True, + enable_prompt_rewrite=True, +).images[0] +image.save(f'./longcat_image_t2i_example.png') +``` + + +This pipeline was contributed by LongCat-Image Team. The original codebase can be found [here](https://github.com/meituan-longcat/LongCat-Image). + +Available models: +
+
+| Models | Type | Description | Download Link |
+|---|---|---|---|
+| LongCat‑Image | Text‑to‑Image | Final Release. The standard model for out‑of‑the‑box inference. | 🤗 Huggingface |
+| LongCat‑Image‑Dev | Text‑to‑Image | Development. Mid-training checkpoint, suitable for fine-tuning. | 🤗 Huggingface |
+| LongCat‑Image‑Edit | Image Editing | Specialized model for image editing. | 🤗 Huggingface |
+
+ +## LongCatImagePipeline + +[[autodoc]] LongCatImagePipeline +- all +- __call__ + +## LongCatImagePipelineOutput + +[[autodoc]] pipelines.longcat_image.pipeline_output.LongCatImagePipelineOutput + + + diff --git a/docs/source/en/api/pipelines/ltx2.md b/docs/source/en/api/pipelines/ltx2.md new file mode 100644 index 000000000000..24776b42309e --- /dev/null +++ b/docs/source/en/api/pipelines/ltx2.md @@ -0,0 +1,220 @@ + + +# LTX-2 + +
+ LoRA +
+ +LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution. + +You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization. + +The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2). + +## Two-stages Generation +Recommended pipeline to achieve production quality generation, this pipeline is composed of two stages: + +- Stage 1: Generate a video at the target resolution using diffusion sampling with classifier-free guidance (CFG). This stage produces a coherent low-noise video sequence that respects the text/image conditioning. +- Stage 2: Upsample the Stage 1 output by 2 and refine details using a distilled LoRA model to improve fidelity and visual quality. Stage 2 may apply lighter CFG to preserve the structure from Stage 1 while enhancing texture and sharpness. + +Sample usage of text-to-video two stages pipeline + +```py +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda:0" +width = 768 +height = 512 + +pipe = LTX2Pipeline.from_pretrained( + "Lightricks/LTX-2", torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." 
+ +# Stage 1 default (non-distilled) inference +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + sigmas=None, + guidance_scale=4.0, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + "Lightricks/LTX-2", + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +# Load Stage 2 distilled LoRA +pipe.load_lora_weights( + "Lightricks/LTX-2", adapter_name="stage_2_distilled", weight_name="ltx-2-19b-distilled-lora-384.safetensors" +) +pipe.set_adapters("stage_2_distilled", 1.0) +# VAE tiling is usually necessary to avoid OOM error when VAE decoding +pipe.vae.enable_tiling() +# Change scheduler to use Stage 2 distilled sigmas as is +new_scheduler = FlowMatchEulerDiscreteScheduler.from_config( + pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None +) +pipe.scheduler = new_scheduler +# Stage 2 inference with distilled LoRA and sigmas +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L218 + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_lora_distilled_sample.mp4", +) +``` + +## Distilled checkpoint generation +Fastest two-stages generation pipeline using a distilled checkpoint. + +```py +import torch +from diffusers.pipelines.ltx2 import LTX2Pipeline, LTX2LatentUpsamplePipeline +from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel +from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES +from diffusers.pipelines.ltx2.export_utils import encode_video + +device = "cuda" +width = 768 +height = 512 +random_seed = 42 +generator = torch.Generator(device).manual_seed(random_seed) +model_path = "rootonchair/LTX-2-19b-distilled" + +pipe = LTX2Pipeline.from_pretrained( + model_path, torch_dtype=torch.bfloat16 +) +pipe.enable_sequential_cpu_offload(device=device) + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." 
+ +frame_rate = 24.0 +video_latent, audio_latent = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=8, + sigmas=DISTILLED_SIGMA_VALUES, + guidance_scale=1.0, + generator=generator, + output_type="latent", + return_dict=False, +) + +latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + model_path, + subfolder="latent_upsampler", + torch_dtype=torch.bfloat16, +) +upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) +upsample_pipe.enable_model_cpu_offload(device=device) +upscaled_video_latent = upsample_pipe( + latents=video_latent, + output_type="latent", + return_dict=False, +)[0] + +video, audio = pipe( + latents=upscaled_video_latent, + audio_latents=audio_latent, + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=3, + noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0], # renoise with first sigma value https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/distilled.py#L178 + sigmas=STAGE_2_DISTILLED_SIGMA_VALUES, + generator=generator, + guidance_scale=1.0, + output_type="np", + return_dict=False, +) +video = (video * 255).round().astype("uint8") +video = torch.from_numpy(video) + +encode_video( + video[0], + fps=frame_rate, + audio=audio[0].float().cpu(), + audio_sample_rate=pipe.vocoder.config.output_sampling_rate, + output_path="ltx2_distilled_sample.mp4", +) +``` + +## LTX2Pipeline + +[[autodoc]] LTX2Pipeline + - all + - __call__ + +## LTX2ImageToVideoPipeline + +[[autodoc]] LTX2ImageToVideoPipeline + - all + - __call__ + +## LTX2LatentUpsamplePipeline + +[[autodoc]] LTX2LatentUpsamplePipeline + - all + - __call__ + +## LTX2PipelineOutput + +[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index 940144538a35..68658f41dabc 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24) - The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`. - For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality. - For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`. - - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video. + - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video. - LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined. @@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
Show example code - + ```python import torch from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline @@ -474,6 +474,12 @@ export_to_video(video, "output.mp4", fps=24)
+## LTXI2VLongMultiPromptPipeline + +[[autodoc]] LTXI2VLongMultiPromptPipeline + - all + - __call__ + ## LTXPipeline [[autodoc]] LTXPipeline diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md index b3dd3dd93618..ee3dd3b28e4d 100644 --- a/docs/source/en/api/pipelines/qwenimage.md +++ b/docs/source/en/api/pipelines/qwenimage.md @@ -95,7 +95,7 @@ image.save("qwen_fewsteps.png") With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference. -``` +```py import torch from PIL import Image from diffusers import QwenImageEditPlusPipeline @@ -108,12 +108,46 @@ pipe = QwenImageEditPlusPipeline.from_pretrained( image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg") image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png") image = pipe( - image=[image_1, image_2], - prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', + image=[image_1, image_2], + prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''', num_inference_steps=50 ).images[0] ``` +## Performance + +### torch.compile + +Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s): + +```python +import torch +from diffusers import QwenImagePipeline + +pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda") +pipe.transformer = torch.compile(pipe.transformer) + +# First call triggers compilation (~7s overhead) +# Subsequent calls run at ~2.4x faster +image = pipe("a cat", num_inference_steps=50).images[0] +``` + +### Batched Inference with Variable-Length Prompts + +When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output. + +```python +# CFG with different prompt lengths works correctly +image = pipe( + prompt="A cat", + negative_prompt="blurry, low quality, distorted", + true_cfg_scale=3.5, + num_inference_steps=50, +).images[0] +``` + +For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f). + ## QwenImagePipeline [[autodoc]] QwenImagePipeline diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 6730f1551607..e1829bc409eb 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -37,7 +37,8 @@ The following SkyReels-V2 models are supported in Diffusers: - [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers) - [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers) - [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers) -- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers) + +This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz). > [!TIP] > Click on the SkyReels-V2 models in the right sidebar for more examples of video generation. 
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md index 6aab6c5b33b9..d5fdbbfe0f95 100644 --- a/docs/source/en/api/pipelines/wan.md +++ b/docs/source/en/api/pipelines/wan.md @@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color. - - - ### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication [Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team. diff --git a/docs/source/en/api/pipelines/z_image.md b/docs/source/en/api/pipelines/z_image.md new file mode 100644 index 000000000000..5175f6b0fb6f --- /dev/null +++ b/docs/source/en/api/pipelines/z_image.md @@ -0,0 +1,66 @@ + + +# Z-Image + +
+ LoRA +
+ +[Z-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently there's only one model with two more to be released: + +|Model|Hugging Face| +|---|---| +|Z-Image-Turbo|https://huggingface.co/Tongyi-MAI/Z-Image-Turbo| + +## Z-Image-Turbo + +Z-Image-Turbo is a distilled version of Z-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence. + +## Image-to-image + +Use [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt. + +```python +import torch +from diffusers import ZImageImg2ImgPipeline +from diffusers.utils import load_image + +pipe = ZImageImg2ImgPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +init_image = load_image(url).resize((1024, 1024)) + +prompt = "A fantasy landscape with mountains and a river, detailed, vibrant colors" +image = pipe( + prompt, + image=init_image, + strength=0.6, + num_inference_steps=9, + guidance_scale=0.0, + generator=torch.Generator("cuda").manual_seed(42), +).images[0] +image.save("zimage_img2img.png") +``` + +## ZImagePipeline + +[[autodoc]] ZImagePipeline + - all + - __call__ + +## ZImageImg2ImgPipeline + +[[autodoc]] ZImageImg2ImgPipeline + - all + - __call__ diff --git a/docs/source/en/hybrid_inference/api_reference.md b/docs/source/en/hybrid_inference/api_reference.md index 865aaba5ebb6..b538cb350481 100644 --- a/docs/source/en/hybrid_inference/api_reference.md +++ b/docs/source/en/hybrid_inference/api_reference.md @@ -1,9 +1,11 @@ -# Hybrid Inference API Reference +# Remote inference -## Remote Decode +Remote inference provides access to an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload local generation requirements for decoding and encoding. + +## remote_decode [[autodoc]] utils.remote_utils.remote_decode -## Remote Encode +## remote_encode [[autodoc]] utils.remote_utils.remote_encode diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md index 7ed1bbb88b3f..1384be9b7348 100644 --- a/docs/source/en/hybrid_inference/overview.md +++ b/docs/source/en/hybrid_inference/overview.md @@ -10,51 +10,296 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Hybrid Inference +# Remote inference -**Empowering local AI builders with Hybrid Inference** +> [!TIP] +> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). +Remote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). Refer to the table below for the supported models and endpoint. -> [!TIP] -> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae). 
-> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+| Model | Endpoint | Checkpoint | Support |
+|---|---|---|---|
+| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |
+| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |
+| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |
+| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |
+
+This guide will show you how to encode and decode latents with remote inference.
+
+## Encoding
+
+Encoding converts images and videos into latent representations. Refer to the table above for the supported VAEs.
+
+Pass an image to [`~utils.remote_encode`] to encode it. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
+
+```py
+import torch
+from diffusers import FluxPipeline
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_encode
+
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    torch_dtype=torch.float16,
+    vae=None,
+    device_map="cuda"
+)
+
+init_image = load_image(
+    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+)
+init_image = init_image.resize((768, 512))
+
+init_latent = remote_encode(
+    endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
+    image=init_image,
+    scaling_factor=0.3611,
+    shift_factor=0.1159
+)
+```
+
+## Decoding
+
+Decoding converts latent representations back into images or videos. Refer to the table above for the available and supported VAEs.
+
+Set the output type to `"latent"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
+
+
+
+
+```py
+import torch
+from diffusers import FluxPipeline
+from diffusers.utils.remote_utils import remote_decode
+
+pipeline = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    torch_dtype=torch.bfloat16,
+    vae=None,
+    device_map="cuda"
+)
+
+prompt = """
+A photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.
+""" + +latent = pipeline( + prompt=prompt, + guidance_scale=0.0, + num_inference_steps=4, + output_type="latent", +).images +image = remote_decode( + endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + height=1024, + width=1024, + scaling_factor=0.3611, + shift_factor=0.1159, +) +image.save("image.jpg") +``` + + + + +```py +import torch +from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + +transformer = HunyuanVideoTransformer3DModel.from_pretrained( + "hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16 +) +pipeline = HunyuanVideoPipeline.from_pretrained( + model_id, transformer=transformer, vae=None, torch_dtype=torch.float16, device_map="cuda" +) + +latent = pipeline( + prompt="A cat walks on the grass, realistic", + height=320, + width=512, + num_frames=61, + num_inference_steps=30, + output_type="latent", +).frames + +video = remote_decode( + endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=latent, + output_type="mp4", +) + +if isinstance(video, bytes): + with open("video.mp4", "wb") as f: + f.write(video) +``` + + + + +## Queuing + +Remote inference supports queuing to process multiple generation requests. While the current latent is being decoded, you can queue the next prompt. + +```py +import queue +import threading +from IPython.display import display +from diffusers import StableDiffusionXLPipeline + +def decode_worker(q: queue.Queue): + while True: + item = q.get() + if item is None: + break + image = remote_decode( + endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", + tensor=item, + scaling_factor=0.13025, + ) + display(image) + q.task_done() + +q = queue.Queue() +thread = threading.Thread(target=decode_worker, args=(q,), daemon=True) +thread.start() + +def decode(latent: torch.Tensor): + q.put(latent) + +prompts = [ + "A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.", + "A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.", + "A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.", + "A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moon’s surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.", + "A monochrome 1969-style press photo of a “first cat on the Moon” moment. The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos." 
+]
+
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+    vae=None,
+    device_map="cuda"
+)
+
+pipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+_ = pipeline(
+    prompt=prompts[0],
+    output_type="latent",
+)
+
+for prompt in prompts:
+    latent = pipeline(
+        prompt=prompt,
+        output_type="latent",
+    ).images
+    decode(latent)
+
+q.put(None)
+thread.join()
+```
+
+## Benchmarks
+
+The tables demonstrate the memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.
+For the majority of these GPUs, the memory usage dictates whether other models (text encoders, UNet/transformer) need to be offloaded, or whether tiled encoding is required. Both techniques increase inference time and impact quality.
+
Encoding - Stable Diffusion v1.5 -## Why use Hybrid Inference? +| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | +|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:| +| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 | +| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 | +| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 | +| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 | +| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 | +| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 | +| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 | +| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 | +| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 | +| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 | +| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 | +| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 | +| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 | -Hybrid Inference offers a fast and simple way to offload local generation requirements. +
-- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware. -- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance. -- 💰 **Cost Effective:** It's free! 🤑 -- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community. -- 🔧 **Developer-Friendly:** Simple requests, fast responses. +
Encoding SDXL ---- +| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | +|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:| +| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 | +| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 | +| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 | +| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 | +| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 | +| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 | +| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 | +| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 | +| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 | +| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 | +| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 | +| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 | +| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 | -## Available Models +
-* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed. -* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training. -* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow. +
Decoding - Stable Diffusion v1.5 ---- +| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | +| --- | --- | --- | --- | --- | --- | +| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% | +| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% | +| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% | +| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% | +| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% | -## Integrations +
-* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. -* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. +
Decoding SDXL -## Changelog +| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | +| --- | --- | --- | --- | --- | --- | +| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% | +| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% | +| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% | +| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% | +| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% | +| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% | +| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% | +| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% | +| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% | +| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% | +| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% | +| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% | -- March 10 2025: Added VAE encode -- March 2 2025: Initial release with VAE decoding +
-## Contents -The documentation is organized into three sections: +## Resources -* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference. -* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference. -* **API Reference** Dive into task-specific settings and parameters. +- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae). +- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more. \ No newline at end of file diff --git a/docs/source/en/hybrid_inference/vae_decode.md b/docs/source/en/hybrid_inference/vae_decode.md deleted file mode 100644 index 1457090550c7..000000000000 --- a/docs/source/en/hybrid_inference/vae_decode.md +++ /dev/null @@ -1,345 +0,0 @@ -# Getting Started: VAE Decode with Hybrid Inference - -VAE decode is an essential component of diffusion models - turning latent representations into images or videos. - -## Memory - -These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs. - -For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality. - -
SD v1.5 - -| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | -| --- | --- | --- | --- | --- | --- | -| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% | -| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% | -| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% | -| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% | -| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% | -| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% | -| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% | -| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% | -| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% | -| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% | -| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% | -| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% | - -
- -
SDXL - -| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | -| --- | --- | --- | --- | --- | --- | -| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% | -| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% | -| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% | -| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% | -| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% | -| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% | -| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% | -| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% | -| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% | -| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% | -| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% | -| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% | - -
- -## Available VAEs - -| | **Endpoint** | **Model** | -|:-:|:-----------:|:--------:| -| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | -| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | -| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | -| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) | - - -> [!TIP] -> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). - - -## Code - -> [!TIP] -> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main` - - -A helper method simplifies interacting with Hybrid Inference. - -```python -from diffusers.utils.remote_utils import remote_decode -``` - -### Basic example - -Here, we show how to use the remote VAE on random tensors. - -
Code - -```python -image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16), - scaling_factor=0.18215, -) -``` - -
- -
- -
- -Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`. - -
Code - -```python -image = remote_decode( - endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=torch.randn([1, 4096, 64], dtype=torch.float16), - height=1024, - width=1024, - scaling_factor=0.3611, - shift_factor=0.1159, -) -``` - -
- -
- -
- -Finally, an example for HunyuanVideo. - -
Code - -```python -video = remote_decode( - endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16), - output_type="mp4", -) -with open("video.mp4", "wb") as f: - f.write(video) -``` - -
- -
- -
- - -### Generation - -But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5. - -
Code - -```python -from diffusers import StableDiffusionPipeline - -pipe = StableDiffusionPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16, - variant="fp16", - vae=None, -).to("cuda") - -prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" - -latent = pipe( - prompt=prompt, - output_type="latent", -).images -image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - scaling_factor=0.18215, -) -image.save("test.jpg") -``` - -
- -
- -
- -Here’s another example with Flux. - -
Code - -```python -from diffusers import FluxPipeline - -pipe = FluxPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", - torch_dtype=torch.bfloat16, - vae=None, -).to("cuda") - -prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious" - -latent = pipe( - prompt=prompt, - guidance_scale=0.0, - num_inference_steps=4, - output_type="latent", -).images -image = remote_decode( - endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - height=1024, - width=1024, - scaling_factor=0.3611, - shift_factor=0.1159, -) -image.save("test.jpg") -``` - -
- -
- -
- -Here’s an example with HunyuanVideo. - -
Code - -```python -from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel - -model_id = "hunyuanvideo-community/HunyuanVideo" -transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=torch.bfloat16 -) -pipe = HunyuanVideoPipeline.from_pretrained( - model_id, transformer=transformer, vae=None, torch_dtype=torch.float16 -).to("cuda") - -latent = pipe( - prompt="A cat walks on the grass, realistic", - height=320, - width=512, - num_frames=61, - num_inference_steps=30, - output_type="latent", -).frames - -video = remote_decode( - endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - output_type="mp4", -) - -if isinstance(video, bytes): - with open("video.mp4", "wb") as f: - f.write(video) -``` - -
- -
- -
- - -### Queueing - -One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency. - - -
Code - -```python -import queue -import threading -from IPython.display import display -from diffusers import StableDiffusionPipeline - -def decode_worker(q: queue.Queue): - while True: - item = q.get() - if item is None: - break - image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=item, - scaling_factor=0.18215, - ) - display(image) - q.task_done() - -q = queue.Queue() -thread = threading.Thread(target=decode_worker, args=(q,), daemon=True) -thread.start() - -def decode(latent: torch.Tensor): - q.put(latent) - -prompts = [ - "Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious", - "Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore", - "Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.", - "Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP", - "A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting", - "Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,", -] - -pipe = StableDiffusionPipeline.from_pretrained( - "Lykon/dreamshaper-8", - torch_dtype=torch.float16, - vae=None, -).to("cuda") - -pipe.unet = pipe.unet.to(memory_format=torch.channels_last) -pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) - -_ = pipe( - prompt=prompts[0], - output_type="latent", -) - -for prompt in prompts: - latent = pipe( - prompt=prompt, - output_type="latent", - ).images - decode(latent) - -q.put(None) -thread.join() -``` - -
- - -
- -
- -## Integrations - -* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. -* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. diff --git a/docs/source/en/hybrid_inference/vae_encode.md b/docs/source/en/hybrid_inference/vae_encode.md deleted file mode 100644 index dd285fa25c03..000000000000 --- a/docs/source/en/hybrid_inference/vae_encode.md +++ /dev/null @@ -1,183 +0,0 @@ -# Getting Started: VAE Encode with Hybrid Inference - -VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations. - -## Memory - -These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs. - -For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality. - -
SD v1.5 - -| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) | -|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:| -| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 | -| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 | -| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 | -| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 | -| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 | -| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 | -| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 | -| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 | -| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 | -| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 | -| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 | -| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 | -| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 | -| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 | -| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 | -| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 | -| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 | -| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 | -| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 | -| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 | - - -
- -
SDXL - -| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) | -|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:| -| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 | -| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 | -| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 | -| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 | -| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 | -| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 | -| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 | -| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 | -| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 | -| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 | -| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 | -| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 | -| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 | -| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 | -| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 | -| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 | -| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 | -| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 | -| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 | -| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 | - -
- -## Available VAEs - -| | **Endpoint** | **Model** | -|:-:|:-----------:|:--------:| -| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) | -| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) | -| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) | - - -> [!TIP] -> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml). - - -## Code - -> [!TIP] -> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main` - - -A helper method simplifies interacting with Hybrid Inference. - -```python -from diffusers.utils.remote_utils import remote_encode -``` - -### Basic example - -Let's encode an image, then decode it to demonstrate. - -
- -
- -
Code - -```python -from diffusers.utils import load_image -from diffusers.utils.remote_utils import remote_decode - -image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true") - -latent = remote_encode( - endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/", - scaling_factor=0.3611, - shift_factor=0.1159, -) - -decoded = remote_decode( - endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - scaling_factor=0.3611, - shift_factor=0.1159, -) -``` - -
- -
- -
- - -### Generation - -Now let's look at a generation example, we'll encode the image, generate then remotely decode too! - -
Code - -```python -import torch -from diffusers import StableDiffusionImg2ImgPipeline -from diffusers.utils import load_image -from diffusers.utils.remote_utils import remote_decode, remote_encode - -pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - "stable-diffusion-v1-5/stable-diffusion-v1-5", - torch_dtype=torch.float16, - variant="fp16", - vae=None, -).to("cuda") - -init_image = load_image( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" -) -init_image = init_image.resize((768, 512)) - -init_latent = remote_encode( - endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/", - image=init_image, - scaling_factor=0.18215, -) - -prompt = "A fantasy landscape, trending on artstation" -latent = pipe( - prompt=prompt, - image=init_latent, - strength=0.75, - output_type="latent", -).images - -image = remote_decode( - endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/", - tensor=latent, - scaling_factor=0.18215, -) -image.save("fantasy_landscape.jpg") -``` - -
- -
- -
- -## Integrations - -* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference. -* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference. diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md index 1c311582264e..b412e0e58abc 100644 --- a/docs/source/en/modular_diffusers/custom_blocks.md +++ b/docs/source/en/modular_diffusers/custom_blocks.md @@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License. [ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block. > [!TIP] -> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana. +> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom blocks. ## Project Structure @@ -31,54 +31,58 @@ Your custom block project should use the following structure: - `block.py` contains the custom block implementation - `modular_config.json` contains the metadata needed to load the block -## Example: Florence 2 Inpainting Block +## Quick Start with Template -In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting. +The fastest way to create a custom block is to start from our template. The template provides a pre-configured project structure with `block.py` and `modular_config.json` files, plus commented examples showing how to define components, inputs, outputs, and the `__call__` method—so you can focus on your custom logic instead of boilerplate setup. -The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub. +### Download the template -```py -# Inside block.py -from diffusers.modular_pipelines import ( - ModularPipelineBlocks, - ComponentSpec, +```python +from diffusers import ModularPipelineBlocks + +model_id = "diffusers/custom-block-template" +local_dir = model_id.split("/")[-1] + +blocks = ModularPipelineBlocks.from_pretrained( + model_id, + trust_remote_code=True, + local_dir=local_dir ) -from transformers import AutoProcessor, Florence2ForConditionalGeneration +``` +This saves the template files to `custom-block-template/` locally or you could use `local_dir` to save to a specific location. 
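+
+If you want to confirm where the files landed, a quick sanity check like the one below works (it assumes the `local_dir` variable from the snippet above and is only an illustrative check, not part of the template):
+
+```python
+import os
+
+# The downloaded folder should contain the template files described in the project structure above.
+print(sorted(os.listdir(local_dir)))  # expect block.py and modular_config.json
+```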
-class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): +### Edit locally - @property - def expected_components(self): - return [ - ComponentSpec( - name="image_annotator", - type_hint=Florence2ForConditionalGeneration, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ComponentSpec( - name="image_annotator_processor", - type_hint=AutoProcessor, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ] +Open `block.py` and implement your custom block. The template includes commented examples showing how to define each property. See the [Florence-2 example](#example-florence-2-image-annotator) below for a complete implementation. + +### Test your block + +```python +from diffusers import ModularPipelineBlocks + +blocks = ModularPipelineBlocks.from_pretrained(local_dir, trust_remote_code=True) +pipeline = blocks.init_pipeline() +output = pipeline(...) # your inputs here ``` -Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations. +### Upload to the Hub -```py -from typing import List, Union -from PIL import Image, ImageDraw -import torch -import numpy as np - -from diffusers.modular_pipelines import ( - PipelineState, - ModularPipelineBlocks, - InputParam, - ComponentSpec, - OutputParam, -) +```python +pipeline.save_pretrained(local_dir, repo_id="your-username/your-block-name", push_to_hub=True) +``` + +## Example: Florence-2 Image Annotator + +This example creates a custom block with [Florence-2](https://huggingface.co/docs/transformers/model_doc/florence2) to process an input image and generate a mask for inpainting. + +### Define components + +Define the components the block needs, `Florence2ForConditionalGeneration` and its processor. When defining components, specify the `name` (how you'll access it in code), `type_hint` (the model class), and `pretrained_model_name_or_path` (where to load weights from). + +```python +# Inside block.py +from diffusers.modular_pipelines import ModularPipelineBlocks, ComponentSpec from transformers import AutoProcessor, Florence2ForConditionalGeneration @@ -98,122 +102,21 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): pretrained_model_name_or_path="florence-community/Florence-2-base-ft", ), ] - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam( - "image", - type_hint=Union[Image.Image, List[Image.Image]], - required=True, - description="Image(s) to annotate", - ), - InputParam( - "annotation_task", - type_hint=Union[str, List[str]], - required=True, - default="", - description="""Annotation Task to perform on the image. - Supported Tasks: - - - - - - - - - - - """, - ), - InputParam( - "annotation_prompt", - type_hint=Union[str, List[str]], - required=True, - description="""Annotation Prompt to provide more context to the task. - Can be used to detect or segment out specific elements in the image - """, - ), - InputParam( - "annotation_output_type", - type_hint=str, - required=True, - default="mask_image", - description="""Output type from annotation predictions. 
Availabe options are - mask_image: - -black and white mask image for the given image based on the task type - mask_overlay: - - mask overlayed on the original image - bounding_box: - - bounding boxes drawn on the original image - """, - ), - InputParam( - "annotation_overlay", - type_hint=bool, - required=True, - default=False, - description="", - ), - ] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - "mask_image", - type_hint=Image, - description="Inpainting Mask for input Image(s)", - ), - OutputParam( - "annotations", - type_hint=dict, - description="Annotations Predictions for input Image(s)", - ), - OutputParam( - "image", - type_hint=Image, - description="Annotated input Image(s)", - ), - ] - ``` -Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask. +### Define inputs and outputs -```py +Inputs include the image, annotation task, and prompt. Outputs include the generated mask and annotations. + +```python from typing import List, Union -from PIL import Image, ImageDraw -import torch -import numpy as np - -from diffusers.modular_pipelines import ( - PipelineState, - ModularPipelineBlocks, - InputParam, - ComponentSpec, - OutputParam, -) -from transformers import AutoProcessor, Florence2ForConditionalGeneration +from PIL import Image +from diffusers.modular_pipelines import InputParam, OutputParam class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): - @property - def expected_components(self): - return [ - ComponentSpec( - name="image_annotator", - type_hint=Florence2ForConditionalGeneration, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ComponentSpec( - name="image_annotator_processor", - type_hint=AutoProcessor, - pretrained_model_name_or_path="florence-community/Florence-2-base-ft", - ), - ] + # ... expected_components from above ... @property def inputs(self) -> List[InputParam]: @@ -226,51 +129,21 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): ), InputParam( "annotation_task", - type_hint=Union[str, List[str]], - required=True, + type_hint=str, default="", - description="""Annotation Task to perform on the image. - Supported Tasks: - - - - - - - - - - - """, + description="Annotation task to perform (e.g., , , )", ), InputParam( "annotation_prompt", - type_hint=Union[str, List[str]], + type_hint=str, required=True, - description="""Annotation Prompt to provide more context to the task. - Can be used to detect or segment out specific elements in the image - """, + description="Prompt to provide context for the annotation task", ), InputParam( "annotation_output_type", type_hint=str, - required=True, default="mask_image", - description="""Output type from annotation predictions. 
Availabe options are - mask_image: - -black and white mask image for the given image based on the task type - mask_overlay: - - mask overlayed on the original image - bounding_box: - - bounding boxes drawn on the original image - """, - ), - InputParam( - "annotation_overlay", - type_hint=bool, - required=True, - default=False, - description="", + description="Output type: 'mask_image', 'mask_overlay', or 'bounding_box'", ), ] @@ -279,109 +152,45 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): return [ OutputParam( "mask_image", - type_hint=Image, - description="Inpainting Mask for input Image(s)", + type_hint=Image.Image, + description="Inpainting mask for the input image", ), OutputParam( "annotations", type_hint=dict, - description="Annotations Predictions for input Image(s)", + description="Raw annotation predictions", ), OutputParam( "image", - type_hint=Image, - description="Annotated input Image(s)", + type_hint=Image.Image, + description="Annotated image", ), ] +``` - def get_annotations(self, components, images, prompts, task): - task_prompts = [task + prompt for prompt in prompts] +### Implement the `__call__` method - inputs = components.image_annotator_processor( - text=task_prompts, images=images, return_tensors="pt" - ).to(components.image_annotator.device, components.image_annotator.dtype) +The `__call__` method contains the block's logic. Access inputs via `block_state`, run your computation, and set outputs back to `block_state`. - generated_ids = components.image_annotator.generate( - input_ids=inputs["input_ids"], - pixel_values=inputs["pixel_values"], - max_new_tokens=1024, - early_stopping=False, - do_sample=False, - num_beams=3, - ) - annotations = components.image_annotator_processor.batch_decode( - generated_ids, skip_special_tokens=False - ) - outputs = [] - for image, annotation in zip(images, annotations): - outputs.append( - components.image_annotator_processor.post_process_generation( - annotation, task=task, image_size=(image.width, image.height) - ) - ) - return outputs - - def prepare_mask(self, images, annotations, overlay=False, fill="white"): - masks = [] - for image, annotation in zip(images, annotations): - mask_image = image.copy() if overlay else Image.new("L", image.size, 0) - draw = ImageDraw.Draw(mask_image) - - for _, _annotation in annotation.items(): - if "polygons" in _annotation: - for polygon in _annotation["polygons"]: - polygon = np.array(polygon).reshape(-1, 2) - if len(polygon) < 3: - continue - polygon = polygon.reshape(-1).tolist() - draw.polygon(polygon, fill=fill) - - elif "bbox" in _annotation: - bbox = _annotation["bbox"] - draw.rectangle(bbox, fill="white") - - masks.append(mask_image) - - return masks - - def prepare_bounding_boxes(self, images, annotations): - outputs = [] - for image, annotation in zip(images, annotations): - image_copy = image.copy() - draw = ImageDraw.Draw(image_copy) - for _, _annotation in annotation.items(): - bbox = _annotation["bbox"] - label = _annotation["label"] - - draw.rectangle(bbox, outline="red", width=3) - draw.text((bbox[0], bbox[1] - 20), label, fill="red") - - outputs.append(image_copy) - - return outputs - - def prepare_inputs(self, images, prompts): - prompts = prompts or "" - - if isinstance(images, Image.Image): - images = [images] - if isinstance(prompts, str): - prompts = [prompts] - - if len(images) != len(prompts): - raise ValueError("Number of images and annotation prompts must match.") - - return images, prompts +```python +import torch +from diffusers.modular_pipelines 
import PipelineState + + +class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): + + # ... expected_components, inputs, intermediate_outputs from above ... @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + images, annotation_task_prompt = self.prepare_inputs( block_state.image, block_state.annotation_prompt ) task = block_state.annotation_task fill = block_state.fill - + annotations = self.get_annotations( components, images, annotation_task_prompt, task ) @@ -400,67 +209,69 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks): self.set_block_state(state, block_state) return components, state - -``` - -Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines. - - - - -```shell -# In the folder with the `block.py` file, run: -diffusers-cli custom_block -``` - -Then upload the block to the Hub: - -```shell -hf upload . . -``` - - - -```py -from block import Florence2ImageAnnotatorBlock -block = Florence2ImageAnnotatorBlock() -block.push_to_hub("") + + # Helper methods for mask/bounding box generation... ``` - - +> [!TIP] +> See the complete implementation at [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator). ## Using Custom Blocks -Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`. +Load a custom block with [`~ModularPipeline.from_pretrained`] and set `trust_remote_code=True`. ```py import torch -from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks -from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS +from diffusers import ModularPipeline from diffusers.utils import load_image -# Fetch the Florence2 image annotator block that will create our mask -image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True) +# Load the Florence-2 annotator pipeline +image_annotator = ModularPipeline.from_pretrained( + "diffusers/Florence2-image-Annotator", + trust_remote_code=True +) -my_blocks = INPAINT_BLOCKS.copy() -# insert the annotation block before the image encoding step -my_blocks.insert("image_annotator", image_annotator_block, 1) +# Check the docstring to see inputs/outputs +print(image_annotator.blocks.doc) +``` -# Create our initial set of inpainting blocks -blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks) +Use the block to generate a mask: -repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0" -pipe = blocks.init_pipeline(repo_id) -pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True) +```python +image_annotator.load_components(torch_dtype=torch.bfloat16) +image_annotator.to("cuda") -image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true") +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg") image = image.resize((1024, 1024)) - prompt = ["A red car"] annotation_task = "" annotation_prompt = ["the car"] +mask_image = image_annotator_node( + prompt=prompt, + image=image, + annotation_task=annotation_task, + annotation_prompt=annotation_prompt, + annotation_output_type="mask_image", +).images +mask_image[0].save("car-mask.png") +``` + 
+Compose it with other blocks to create a new pipeline:
+
+```python
+# Get the annotator block
+annotator_block = image_annotator.blocks
+
+# Get an inpainting workflow and insert the annotator at the beginning
+inpaint_blocks = ModularPipeline.from_pretrained("Qwen/Qwen-Image").blocks.get_workflow("inpainting")
+inpaint_blocks.sub_blocks.insert("image_annotator", annotator_block, 0)
+
+# Initialize the combined pipeline
+pipe = inpaint_blocks.init_pipeline()
+pipe.load_components(torch_dtype=torch.float16, device="cuda")
+
+# Now the pipeline automatically generates masks from prompts
 output = pipe(
     prompt=prompt,
     image=image,
@@ -475,18 +286,50 @@ output = pipe(
 
 output[0].save("florence-inpainting.png")
 ```
 
-## Editing Custom Blocks
+## Editing custom blocks
 
-By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
+Edit a custom block by downloading it locally. This is the same workflow as the [Quick Start with Template](#quick-start-with-template), but starting from an existing block instead of the template.
 
-```py
-import torch
-from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
-from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
-from diffusers.utils import load_image
+Use the `local_dir` argument to download a custom block to a specific folder:
+
+```python
+from diffusers import ModularPipelineBlocks
 
-# Fetch the Florence2 image annotator block that will create our mask
-image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
+# Download to a local folder for editing
+annotator_block = ModularPipelineBlocks.from_pretrained(
+    "diffusers/Florence2-image-Annotator",
+    trust_remote_code=True,
+    local_dir="./my-florence-block"
+)
 ```
 
-Any changes made to the block files in this folder will be reflected when you load the block again.
+Any changes made to the block files in this folder will be reflected when you load the block again. When you're ready to share your changes, upload to a new repository:
+
+```python
+pipeline = annotator_block.init_pipeline()
+pipeline.save_pretrained("./my-florence-block", repo_id="your-username/my-custom-florence", push_to_hub=True)
+```
+
+## Next Steps
+
+
+
+
+This guide covered creating a single custom block. Learn how to compose multiple blocks together:
+
+- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to execute in sequence
+- [ConditionalPipelineBlocks](./auto_pipeline_blocks): Create conditional blocks that select different execution paths
+- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks): Define iterative workflows like the denoising loop
+
+
+
+
+Make your custom block work with Mellon's visual interface. See the [Mellon Custom Blocks](./mellon) guide.
+
+
+
+
+Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
+ + + \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md index a80309de19a6..74a868922799 100644 --- a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md +++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md @@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently. -- It recieves the iteration variable from the loop wrapper. +- It receives the iteration variable from the loop wrapper. - It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`]. - It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`]. diff --git a/docs/source/en/modular_diffusers/mellon.md b/docs/source/en/modular_diffusers/mellon.md new file mode 100644 index 000000000000..808e62ad7966 --- /dev/null +++ b/docs/source/en/modular_diffusers/mellon.md @@ -0,0 +1,270 @@ + + + +## Using Custom Blocks with Mellon + +[Mellon](https://github.com/cubiq/Mellon) is a visual workflow interface that integrates with Modular Diffusers and is designed for node-based workflows. + +> [!WARNING] +> Mellon is in early development and not ready for production use yet. Consider this a sneak peek of how the integration works! + + +Custom blocks work in Mellon out of the box - just need to add a `mellon_pipeline_config.json` to your repository. This config file tells Mellon how to render your block's parameters as UI components. + +Here's what it looks like in action with the [Gemini Prompt Expander](https://huggingface.co/diffusers/gemini-prompt-expander-mellon) block: + +![Mellon custom block demo](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/modular_demo_dynamic.gif) + +To use a modular diffusers custom block in Mellon: +1. Drag a **Dynamic Block Node** from the ModularDiffusers section +2. Enter the `repo_id` (e.g., `diffusers/gemini-prompt-expander-mellon`) +3. Click **Load Custom Block** +4. The node transforms to show your block's inputs and outputs + +Now let's walk through how to create this config for your own custom block. + +## Steps to create a Mellon config + +1. **Specify Mellon types for your parameters** - Each `InputParam`/`OutputParam` needs a type that tells Mellon what UI component to render (e.g., `"textbox"`, `"dropdown"`, `"image"`). +2. **Generate `mellon_pipeline_config.json`** - Use our utility to generate a config template and push it to your Hub repository. +3. **(Optional) Manually adjust the config** - Fine-tune the generated config for your specific needs. + +## Specify Mellon types for parameters + +Mellon types determine how each parameter renders in the UI. If you don't specify a type for a parameter, it will default to `"custom"`, which renders as a simple connection dot. You can always adjust this later in the generated config. 
+ + +| Type | Input/Output | Description | +|------|--------------|-------------| +| `image` | Both | Image (PIL Image) | +| `video` | Both | Video | +| `text` | Both | Text display | +| `textbox` | Input | Text input | +| `dropdown` | Input | Dropdown selection menu | +| `slider` | Input | Slider for numeric values | +| `number` | Input | Numeric input | +| `checkbox` | Input | Boolean toggle | + +For parameters that need more configuration (like dropdowns with options, or sliders with min/max values), pass a `MellonParam` instance directly instead of a string. You can use one of the class methods below, or create a fully custom one with `MellonParam(name, label, type, ...)`. + +| Method | Description | +|--------|-------------| +| `MellonParam.Input.image(name)` | Image input | +| `MellonParam.Input.textbox(name, default)` | Text input as textarea | +| `MellonParam.Input.dropdown(name, options, default)` | Dropdown selection | +| `MellonParam.Input.slider(name, default, min, max, step)` | Slider for numeric values | +| `MellonParam.Input.number(name, default, min, max, step)` | Numeric input (no slider) | +| `MellonParam.Input.seed(name, default)` | Seed input with randomize button | +| `MellonParam.Input.checkbox(name, default)` | Boolean checkbox | +| `MellonParam.Input.model(name)` | Model input for diffusers components | +| `MellonParam.Output.image(name)` | Image output | +| `MellonParam.Output.video(name)` | Video output | +| `MellonParam.Output.text(name)` | Text output | +| `MellonParam.Output.model(name)` | Model output for diffusers components | + +Choose one of the methods below to specify a Mellon type. + +### Using `metadata` in block definitions + +If you're defining a custom block from scratch, add `metadata={"mellon": ""}` directly to your `InputParam` and `OutputParam` definitions. If you're editing an existing custom block from the Hub, see [Editing custom blocks](./custom_blocks#editing-custom-blocks) for how to download it locally. 
+ +```python +class GeminiPromptExpander(ModularPipelineBlocks): + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "prompt", + type_hint=str, + required=True, + description="Prompt to use", + metadata={"mellon": "textbox"}, # Text input + ) + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt", + type_hint=str, + description="Expanded prompt by the LLM", + metadata={"mellon": "text"}, # Text output + ), + OutputParam( + "old_prompt", + type_hint=str, + description="Old prompt provided by the user", + # No metadata - we don't want to render this in UI + ) + ] +``` + +For full control over UI configuration, pass a `MellonParam` instance directly: +```python +from diffusers.modular_pipelines.mellon_node_utils import MellonParam + +InputParam( + "mode", + type_hint=str, + default="balanced", + metadata={"mellon": MellonParam.Input.dropdown("mode", options=["fast", "balanced", "quality"])}, +) +``` + +### Using `input_types` and `output_types` when Generating Config + +If you're working with an existing pipeline or prefer to keep your block definitions clean, specify types when generating the config using the `input_types/output_types` argument: +```python +from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig + +mellon_config = MellonPipelineConfig.from_custom_block( + blocks, + input_types={"prompt": "textbox"}, + output_types={"prompt": "text"} +) +``` + +> [!NOTE] +> When both `metadata` and `input_types`/`output_types` are specified, the arguments overrides `metadata`. + +## Generate and push the Mellon config + +After adding metadata to your block, generate the default Mellon configuration template and push it to the Hub: + +```python +from diffusers import ModularPipelineBlocks +from diffusers.modular_pipelines.mellon_node_utils import MellonPipelineConfig + +# load your custom blocks from your local dir +blocks = ModularPipelineBlocks.from_pretrained("/path/local/folder", trust_remote_code=True) + +# Generate the default config template +mellon_config = MellonPipelineConfig.from_custom_block(blocks) +# push the default template to `repo_id`, you will need to pass the same local folder path so that it will save the config locally first +mellon_config.save( + local_dir="/path/local/folder", + repo_id= repo_id, + push_to_hub=True +) +``` + +This creates a `mellon_pipeline_config.json` file in your repository. + +## Review and adjust the config + +The generated template is a starting point - you may want to adjust it for your needs. Let's walk through the generated config for the Gemini Prompt Expander: + +```json +{ + "label": "Gemini Prompt Expander", + "default_repo": "", + "default_dtype": "", + "node_params": { + "custom": { + "params": { + "prompt": { + "label": "Prompt", + "type": "string", + "display": "textarea", + "default": "" + }, + "out_prompt": { + "label": "Prompt", + "type": "string", + "display": "output" + }, + "old_prompt": { + "label": "Old Prompt", + "type": "custom", + "display": "output" + }, + "doc": { + "label": "Doc", + "type": "string", + "display": "output" + } + }, + "input_names": ["prompt"], + "model_input_names": [], + "output_names": ["out_prompt", "old_prompt", "doc"], + "block_name": "custom", + "node_type": "custom" + } + } +} +``` + +### Understanding the Structure + +The `params` dict defines how each UI element renders. 
The `input_names`, `model_input_names`, and `output_names` lists map these UI elements to the underlying [`ModularPipelineBlocks`]'s I/O interface: + +| Mellon Config | ModularPipelineBlocks | +|---------------|----------------------| +| `input_names` | `inputs` property | +| `model_input_names` | `expected_components` property | +| `output_names` | `intermediate_outputs` property | + +In this example: `prompt` is the only input. There are no model components, and outputs include `out_prompt`, `old_prompt`, and `doc`. + +Now let's look at the `params` dict: + +- **`prompt`**: An input parameter with `display: "textarea"` (renders as a text input box), `label: "Prompt"` (shown in the UI), and `default: ""` (starts empty). The `type: "string"` field is important in Mellon because it determines which nodes can connect together - only matching types can be linked with "noodles". + +- **`out_prompt`**: The expanded prompt output. The `out_` prefix was automatically added because the input and output share the same name (`prompt`), avoiding naming conflicts in the config. It has `display: "output"` which renders as an output socket. + +- **`old_prompt`**: Has `type: "custom"` because we didn't specify metadata. This renders as a simple dot in the UI. Since we don't actually want to expose this in the UI, we can remove it. + +- **`doc`**: The documentation output, automatically added to all custom blocks. + +### Making Adjustments + +Remove `old_prompt` from both `params` and `output_names` because you won't need to use it. + +```json +{ + "label": "Gemini Prompt Expander", + "default_repo": "", + "default_dtype": "", + "node_params": { + "custom": { + "params": { + "prompt": { + "label": "Prompt", + "type": "string", + "display": "textarea", + "default": "" + }, + "out_prompt": { + "label": "Prompt", + "type": "string", + "display": "output" + }, + "doc": { + "label": "Doc", + "type": "string", + "display": "output" + } + }, + "input_names": ["prompt"], + "model_input_names": [], + "output_names": ["out_prompt", "doc"], + "block_name": "custom", + "node_type": "custom" + } + } +} +``` + +See the final config at [diffusers/gemini-prompt-expander-mellon](https://huggingface.co/diffusers/gemini-prompt-expander-mellon). \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md index 7d07c4b73434..83975200d664 100644 --- a/docs/source/en/modular_diffusers/overview.md +++ b/docs/source/en/modular_diffusers/overview.md @@ -24,7 +24,7 @@ The Modular Diffusers docs are organized as shown below. ## Quickstart -- A [quickstart](./quickstart) demonstrating how to implement an example workflow with Modular Diffusers. +- The [quickstart](./quickstart) shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it. ## ModularPipelineBlocks @@ -33,9 +33,14 @@ The Modular Diffusers docs are organized as shown below. - [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together. - [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`]. 
- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`]. +- [Building Custom Blocks](./custom_blocks) shows you how to create your own custom blocks and share them on the Hub. ## ModularPipeline - [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`]. - [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines. -- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline. \ No newline at end of file +- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline. + +## Mellon Integration + +- [Using Custom Blocks with Mellon](./mellon) shows you how to make your custom blocks work with [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. \ No newline at end of file diff --git a/docs/source/en/modular_diffusers/quickstart.md b/docs/source/en/modular_diffusers/quickstart.md index 32d14d84e243..5a455f0b3093 100644 --- a/docs/source/en/modular_diffusers/quickstart.md +++ b/docs/source/en/modular_diffusers/quickstart.md @@ -12,333 +12,248 @@ specific language governing permissions and limitations under the License. # Quickstart -Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface developers can use. +Modular Diffusers is a framework for quickly building flexible and customizable pipelines. These pipelines can go beyond what standard `DiffusionPipeline`s can do. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface for running generation tasks. -This doc will show you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework. +This guide shows you how to run a modular pipeline, understand its structure, and customize it by modifying the blocks that compose it. -## ModularPipelineBlocks +## Run a pipeline -[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks. +[`ModularPipeline`] is the main interface for loading, running, and managing modular pipelines. +```py +import torch +from diffusers import ModularPipeline, ComponentsManager -- [`ModularPipelineBlocks`] is the most basic block for a single step. -- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block. -- [`LoopSequentialPipelineBlocks`] is a multi-block that runs iteratively and is designed for iterative workflows. -- [`AutoPipelineBlocks`] is a collection of blocks for different workflows and it selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline. 
+# Use ComponentsManager to enable auto CPU offloading for memory efficiency +manager = ComponentsManager() +manager.enable_auto_cpu_offload(device="cuda:0") -[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation. +pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image", components_manager=manager) +pipe.load_components(torch_dtype=torch.bfloat16) -```py -from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS -IMAGE2IMAGE_BLOCKS = InsertableDict([ - ("text_encoder", StableDiffusionXLTextEncoderStep), - ("image_encoder", StableDiffusionXLVaeEncoderStep), - ("input", StableDiffusionXLInputStep), - ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep), - ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep), - ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep), - ("denoise", StableDiffusionXLDenoiseStep), - ("decode", StableDiffusionXLDecodeStep) -]) +image = pipe( + prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney", +).images[0] +image ``` -## Pipeline and block states - -Modular Diffusers uses *state* to communicate data between blocks. There are two types of states. +[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded. -- [`PipelineState`] is a global state that can be used to track all inputs and outputs across all blocks. -- [`BlockState`] is a local view of relevant variables from [`PipelineState`] for an individual block. +> [!TIP] +> [`ComponentsManager`] with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide. -## Customizing blocks +Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides. -[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but you'll need to modify these two. +## Understand the structure -Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones. - -Print the `denoise` block to see that it is composed of [`LoopSequentialPipelineBlocks`] with three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map. +A [`ModularPipeline`] has two parts: +- **State**: the loaded components (models, schedulers, processors) and configuration +- **Definition**: the [`ModularPipelineBlocks`] that specify inputs, outputs, expected components and computation logic +The blocks define *what* the pipeline does. Access them through `pipe.blocks`. 
```py -denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]() -print(denoise_blocks) +print(pipe.blocks) ``` - -Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with the new `SDXLDiffDiffLoopBeforeDenoiser` block. - -```py -# Copy existing blocks as placeholders -class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks): - """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later""" - # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep - -class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper): - block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] ``` - -### prepare_latents - -The `prepare_latents` block requires the following changes. - -- a processor to process the change map -- a new `inputs` to accept the user-provided change map, `timestep` for precomputing all the latents and `num_inference_steps` to create the mask for updating the image regions -- update the computation in the `__call__` method for processing the change map and creating the masks, and storing it in the [`BlockState`] - -```diff -class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks): - @property - def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("vae", AutoencoderKL), - ComponentSpec("scheduler", EulerDiscreteScheduler), -+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True})) - ] - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("generator"), -+ InputParam("diffdiff_map", required=True), -- InputParam("latent_timestep", required=True, type_hint=torch.Tensor), -+ InputParam("timesteps", type_hint=torch.Tensor), -+ InputParam("num_inference_steps", type_hint=int), - ] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ -+ OutputParam("original_latents", type_hint=torch.Tensor), -+ OutputParam("diffdiff_masks", type_hint=torch.Tensor), - ] - def __call__(self, components, state: PipelineState): - # ... existing logic ... -+ # Process change map and create masks -+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width) -+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps -+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0)) -+ block_state.original_latents = block_state.latents -``` - -### denoise - -The `before_denoiser` sub-block requires the following changes. 
- -- a new `inputs` to accept a `denoising_start` parameter, `original_latents` and `diffdiff_masks` from the `prepare_latents` block -- update the computation in the `__call__` method for applying Differential Diffusion - -```diff -class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks): - @property - def description(self) -> str: - return ( - "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser" - ) - - @property - def inputs(self) -> List[str]: - return [ - InputParam("latents", required=True, type_hint=torch.Tensor), -+ InputParam("denoising_start"), -+ InputParam("original_latents", type_hint=torch.Tensor), -+ InputParam("diffdiff_masks", type_hint=torch.Tensor), - ] - - def __call__(self, components, block_state, i, t): -+ # Apply differential diffusion logic -+ if i == 0 and block_state.denoising_start is None: -+ block_state.latents = block_state.original_latents[:1] -+ else: -+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1) -+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask) - - # ... rest of existing logic ... +QwenImageAutoBlocks( + Class: SequentialPipelineBlocks + + Description: Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage. + + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `prompt`, `image` + - `inpainting`: requires `prompt`, `mask_image`, `image` + - `controlnet_text2image`: requires `prompt`, `control_image` + ... + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + vae (`AutoencoderKLQwenImage`) + transformer (`QwenImageTransformer2DModel`) + ... + + Sub-Blocks: + [0] text_encoder (QwenImageAutoTextEncoderStep) + [1] vae_encoder (QwenImageAutoVaeEncoderStep) + [2] controlnet_vae_encoder (QwenImageOptionalControlNetVaeEncoderStep) + [3] denoise (QwenImageAutoCoreDenoiseStep) + [4] decode (QwenImageAutoDecodeStep) +) ``` -## Assembling the blocks - -You should have all the blocks you need at this point to create a [`ModularPipeline`]. +The output returns: +- The supported workflows (text2image, image2image, inpainting, etc.) +- The Sub-Blocks it's composed of (text_encoder, vae_encoder, denoise, decode) -Copy the existing `IMAGE2IMAGE_BLOCKS` preset and for the `set_timesteps` block, use the `set_timesteps` from the `TEXT2IMAGE_BLOCKS` because Differential Diffusion doesn't require a `strength` parameter. - -Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified. - -Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a `SequentialPipelineBlocks`. +### Workflows +`QwenImageAutoBlocks` is a [`ConditionalPipelineBlocks`], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Let's see this in action with an example. 
```py -DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() -DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] -DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep -DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep +from diffusers.utils import load_image -dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS) -print(dd_blocks) -``` +input_image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true") -## ModularPipeline - -Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_components`]. - -It is a good idea to initialize the [`ComponentManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization. +image = pipe( + prompt="cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney", + image=input_image, +).images[0] +``` +Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow. ```py -from diffusers.modular_pipelines import ComponentsManager - -components = ComponentManager() - -dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff") -dd_pipeline.load_componenets(torch_dtype=torch.float16) -dd_pipeline.to("cuda") +img2img_blocks = pipe.blocks.get_workflow("image2image") ``` -## Adding workflows +Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide. -Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch. +### Sub-blocks -This section demonstrates how to add an IP-Adapter or ControlNet. +Blocks can contain other blocks. `pipe.blocks` gives you the top-level block definition (here, `QwenImageAutoBlocks`), while `sub_blocks` lets you access the smaller blocks inside it. -### IP-Adapter +`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property. -Stable Diffusion XL already has a preset IP-Adapter block that you can use and doesn't require any changes to the existing Differential Diffusion pipeline. +The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components. +```py +vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"] +print(vae_encoder_block.doc) +``` +This block can be converted to a pipeline so that it can run on its own with [`~ModularPipelineBlocks.init_pipeline`]. 
```py -from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep +vae_encoder_pipe = vae_encoder_block.init_pipeline() -ip_adapter_block = StableDiffusionXLAutoIPAdapterStep() +# Reuse the VAE we already loaded, we can reuse it with update_components() method +vae_encoder_pipe.update_components(vae=pipe.vae) + +# Run just this block +image_latents = vae_encoder_pipe(image=input_image).image_latents +print(image_latents.shape) ``` -Use the [`sub_blocks.insert`] method to insert it into the [`ModularPipeline`]. The example below inserts the `ip_adapter_block` at position `0`. Print the pipeline to see that the `ip_adapter_block` is added and it requires an `ip_adapter_image`. This also added two components to the pipeline, the `image_encoder` and `feature_extractor`. +It reuses the VAE from our original pipeline instead of reloading it, keeping memory usage efficient. Learn more in the [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guide. -```py -dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0) -``` +Since blocks are composable, you can modify the pipeline's definition by adding, removing, or swapping blocks to create new workflows. In the next section, we'll add a canny edge detection block to a ControlNet pipeline, so you can pass a regular image instead of a pre-processed canny edge map. -Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline. +## Compose new workflows +Let's add a canny edge detection block to a ControlNet pipeline. First, load a pre-built canny block from the Hub (see [Building Custom Blocks](https://huggingface.co/docs/diffusers/modular_diffusers/custom_blocks) to create your own). ```py -dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") -dd_pipeline.load_components(torch_dtype=torch.float16) -dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") -dd_pipeline.loader.set_ip_adapter_scale(0.6) -dd_pipeline = dd_pipeline.to(device) - -ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg") -image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") -mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") +from diffusers.modular_pipelines import ModularPipelineBlocks -prompt = "a green pear" -negative_prompt = "blurry" -generator = torch.Generator(device=device).manual_seed(42) +# Load a canny block from the Hub +canny_block = ModularPipelineBlocks.from_pretrained( + "diffusers-internal-dev/canny-filtering", + trust_remote_code=True, +) -image = dd_pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=25, - generator=generator, - ip_adapter_image=ip_adapter_image, - diffdiff_map=mask, - image=image, - output="images" -)[0] +print(canny_block.doc) +``` +``` +class CannyBlock + + Inputs: + image (`Union[Image, ndarray]`): + Image to compute canny filter on + low_threshold (`int`, *optional*, defaults to 50): + Low threshold for the canny filter. + high_threshold (`int`, *optional*, defaults to 200): + High threshold for the canny filter. + ... 
+
+  Outputs:
+    control_image (`PIL.Image`):
+        Canny map for input image
 ```
 
-### ControlNet
-
-Stable Diffusion XL already has a preset ControlNet block that can readily be used.
-
+Use `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
 ```py
-from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
-
-control_input_block = StableDiffusionXLAutoControlNetInputStep()
+# Get the controlnet workflow that we want to work with
+blocks = pipe.blocks.get_workflow("controlnet_text2image")
+print(blocks.doc)
+```
+```
+class SequentialPipelineBlocks
+
+  Inputs:
+    prompt (`str`):
+        The prompt or prompts to guide image generation.
+    control_image (`Image`):
+        Control image for ControlNet conditioning.
+    ...
 ```
 
-However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.
-
-Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.
+The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) - a multi-block type where blocks run one after another and data flows linearly from one block to the next. Each block's `intermediate_outputs` become available as `inputs` to subsequent blocks.
 
+Currently this workflow requires `control_image` as input. Let's insert the canny block at the beginning so the pipeline accepts a regular image instead.
 ```py
-class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
-    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
-    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+# Insert canny at the beginning
+blocks.sub_blocks.insert("canny", canny_block, 0)
 
-controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
+# Check the updated structure: CannyBlock is now listed as first sub-block
+print(blocks)
+# Check the updated doc
+print(blocks.doc)
+```
+```
+class SequentialPipelineBlocks
+
+  Inputs:
+    image (`Union[Image, ndarray]`):
+        Image to compute canny filter on
+    low_threshold (`int`, *optional*, defaults to 50):
+        Low threshold for the canny filter.
+    high_threshold (`int`, *optional*, defaults to 200):
+        High threshold for the canny filter.
+    prompt (`str`):
+        The prompt or prompts to guide image generation.
+    ...
 ```
 
-Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_components`] into it.
+Now the pipeline takes `image` as input instead of `control_image`. Because blocks in a sequence share data automatically, the canny block's output (`control_image`) flows to the denoise block that needs it, and the canny block's input (`image`) becomes a pipeline input since no earlier block provides it.
 
+Create a pipeline from the modified blocks and load a ControlNet model.
```py -dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7) -dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block - -dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") -dd_pipeline.load_components(torch_dtype=torch.float16) -dd_pipeline = dd_pipeline.to(device) - -control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg") -image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true") -mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true") +pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager) -prompt = "a green pear" -negative_prompt = "blurry" -generator = torch.Generator(device=device).manual_seed(42) +pipeline.load_components(torch_dtype=torch.bfloat16) -image = dd_pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=25, - generator=generator, - control_image=control_image, - controlnet_conditioning_scale=0.5, - diffdiff_map=mask, - image=image, - output="images" -)[0] +# Load the ControlNet model +controlnet_spec = pipeline.get_component_spec("controlnet") +controlnet_spec.pretrained_model_name_or_path = "InstantX/Qwen-Image-ControlNet-Union" +controlnet = controlnet_spec.load(torch_dtype=torch.bfloat16) +pipeline.update_components(controlnet=controlnet) ``` -### AutoPipelineBlocks - -The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] by using [`AutoPipelineBlocks`]. This allows automatically selecting which sub-blocks to run based on the inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, then it defaults to the Differential Diffusion. - -Use `block_trigger_inputs` to only run the `SDXLDiffDiffControlNetDenoiseStep` block if a `control_image` input is provided. Otherwise, the `SDXLDiffDiffDenoiseStep` is used. - +Now run the pipeline - the canny block preprocesses the image for ControlNet. ```py -class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep] - block_names = ["controlnet_denoise", "denoise"] - block_trigger_inputs = ["controlnet_cond", None] -``` +from diffusers.utils import load_image -Add the `ip_adapter` and `controlnet_input` blocks. +prompt = "cat wizard with red hat, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney" +image = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/cute_cat.png?raw=true") -```py -DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy() -DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep -DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"] -DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep -DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0) -DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7) +output = pipeline( + prompt=prompt, + image=image, +).images[0] +output ``` -Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`] and create a [`ModularPipeline`] and load in the model components to run. 
- -```py -dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS) -dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff") -dd_pipeline.load_components(torch_dtype=torch.float16) -``` +## Next steps -## Share + + -Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set `push_to_hub` argument to `True`. +Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide. -```py -dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True) -``` + + -Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`]. +Use [`ComponentsManager`](./components_manager) to share models across multiple pipelines and manage memory efficiently. -```py -import torch -from diffusers.modular_pipelines import ModularPipeline, ComponentsManager + + -components = ComponentsManager() +Connect modular pipelines to [Mellon](https://github.com/cubiq/Mellon), a visual node-based interface for building workflows. Custom blocks built with Modular Diffusers work out of the box with Mellon - no UI code required. Read more in the Mellon guide. -diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff") -diffdiff_pipeline.load_components(torch_dtype=torch.float16) -``` + + \ No newline at end of file diff --git a/docs/source/en/optimization/cache.md b/docs/source/en/optimization/cache.md index 881529b27ff1..4eccd70cb304 100644 --- a/docs/source/en/optimization/cache.md +++ b/docs/source/en/optimization/cache.md @@ -66,4 +66,102 @@ config = FasterCacheConfig( tensor_format="BFCHW", ) pipeline.transformer.enable_cache(config) -``` \ No newline at end of file +``` + +## FirstBlockCache + +[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser changes from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output. + +```py +import torch +from diffusers import DiffusionPipeline +from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig + +pipeline = DiffusionPipeline.from_pretrained( + "Qwen/Qwen-Image", torch_dtype=torch.bfloat16 +) +apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2)) +``` +## TaylorSeer Cache + +[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations. + +This caching mechanism delivers strong results with minimal additional memory overhead. For detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080). 
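To make the approximation concrete, here is a toy sketch (not the diffusers implementation) of how a first-order Taylor expansion forecasts an activation from previously cached values instead of recomputing it:

```python
# Illustrative only: a first-order Taylor "forecast" of a cached activation.
# TaylorSeer-style caching predicts the next value from past full computations.
def taylor_forecast(prev, prev_prev, dt=1.0):
    first_derivative = (prev - prev_prev) / dt  # finite-difference estimate
    return prev + first_derivative * dt         # f(t + dt) ~ f(t) + f'(t) * dt

# activations from the last two full forward passes (toy scalars)
cached = [0.90, 0.84]
predicted = taylor_forecast(cached[-1], cached[-2])
print(predicted)  # reused in place of a full forward pass for intermediate steps
```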
+ +To enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer: + +- `cache_interval`: Number of steps to reuse cached outputs before performing a full forward pass +- `disable_cache_before_step`: Initial steps that use full computations to gather data for approximations +- `max_order`: Approximation accuracy (in theory, higher values improve quality but increase memory usage but we recommend it should be set to `1`) + +```python +import torch +from diffusers import FluxPipeline, TaylorSeerCacheConfig + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, +).to("cuda") + +config = TaylorSeerCacheConfig( + cache_interval=5, + max_order=1, + disable_cache_before_step=10, + taylor_factors_dtype=torch.bfloat16, +) +pipe.transformer.enable_cache(config) +``` + +## MagCache + +[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual. + +MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler. + +### Usage + +To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**. + +1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console. +2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration. + +```python +import torch +from diffusers import FluxPipeline, MagCacheConfig + +pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-schnell", + torch_dtype=torch.bfloat16 +).to("cuda") + +# 1. Calibration Step +# Run full inference to measure model behavior. +calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4) +pipe.transformer.enable_cache(calib_config) + +# Run a prompt to trigger calibration +pipe("A cat playing chess", num_inference_steps=4) +# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]" + +# 2. Inference Step +# Apply the specific ratios obtained from calibration for optimized speed. +# Note: For Flux models, you can also import defaults: +# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS +mag_config = MagCacheConfig( + mag_ratios=[1.0, 1.37, 0.97, 0.87], + num_inference_steps=4 +) + +pipe.transformer.enable_cache(mag_config) + +image = pipe("A cat playing chess", num_inference_steps=4).images[0] +``` + +> [!NOTE] +> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps. + +> [!TIP] +> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional). 
+ +> [!TIP] +> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification. diff --git a/docs/source/en/quantization/modelopt.md b/docs/source/en/quantization/modelopt.md index 06933d47c221..c7fca9d44491 100644 --- a/docs/source/en/quantization/modelopt.md +++ b/docs/source/en/quantization/modelopt.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # NVIDIA ModelOpt -[NVIDIA-ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed. +[NVIDIA-ModelOpt](https://github.com/NVIDIA/Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed. Before you begin, make sure you have nvidia_modelopt installed. @@ -57,7 +57,7 @@ image.save("output.png") > > The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration. > -> More details can be found [here](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples). +> More details can be found [here](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples). ## NVIDIAModelOptConfig @@ -86,7 +86,7 @@ The quantization methods supported are as follows: | **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`| -Refer to the [official modelopt documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options available. +Refer to the [official modelopt documentation](https://nvidia.github.io/Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options available. 
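As a rough sketch of how these options fit together (the argument names follow the table above; the checkpoint and values are illustrative, so double-check them against the ModelOpt docs), an FP8 weight-only config could look like:

```python
import torch
from diffusers import AutoModel, NVIDIAModelOptConfig

# Assumed example: quant_type strings and the extra arguments (channel_quantize,
# block_quantize) come from the table above; verify against your installed version.
quant_config = NVIDIAModelOptConfig(quant_type="FP8")

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint, swap in your own
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```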
## Serializing and Deserializing quantized models diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 18cc109e0785..aa415dbc36e2 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig( ) pipeline = DiffusionPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", - quantzation_config=pipeline_quant_config, + quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig( ) pipeline = DiffusionPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", - quantzation_config=pipeline_quant_config, + quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig( ) pipeline = DiffusionPipeline.from_pretrained( "black-forest-labs/FLUX.1-dev", - quantzation_config=pipeline_quant_config, + quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -83,25 +83,6 @@ Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue- > [!TIP] > The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible. -## autoquant - -torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment. - -```py -import torch -from diffusers import DiffusionPipeline -from torchao.quantization import autoquant - -# Load the pipeline -pipeline = DiffusionPipeline.from_pretrained( - "black-forest-labs/FLUX.1-schnell", - torch_dtype=torch.bfloat16, - device_map="cuda" -) - -transformer = autoquant(pipeline.transformer) -``` - ## Supported quantization types torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7. diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md index f9756e1a67aa..bdaa2ae8ffff 100644 --- a/docs/source/en/training/distributed_inference.md +++ b/docs/source/en/training/distributed_inference.md @@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends. +Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible. + ### Ring Attention Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. 
No single GPU holds the full sequence, which reduces communication latency. @@ -245,38 +247,58 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf ```py import torch -from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig - -try: - torch.distributed.init_process_group("nccl") - rank = torch.distributed.get_rank() - device = torch.device("cuda", rank % torch.cuda.device_count()) +from torch import distributed as dist +from diffusers import DiffusionPipeline, ContextParallelConfig + +def setup_distributed(): + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - - transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2)) - pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda") - pipeline.transformer.set_attention_backend("flash") + return device + +def main(): + device = setup_distributed() + world_size = dist.get_world_size() + + pipeline = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ).to(device) + pipeline.transformer.set_attention_backend("_native_cudnn") + + cp_config = ContextParallelConfig(ring_degree=world_size) + pipeline.transformer.enable_parallelism(config=cp_config) prompt = """ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain """ - + # Must specify generator so all ranks start with same latents (or pass your own) generator = torch.Generator().manual_seed(42) - image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0] - - if rank == 0: - image.save("output.png") - -except Exception as e: - print(f"An error occurred: {e}") - torch.distributed.breakpoint() - raise - -finally: - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() + image = pipeline( + prompt, + guidance_scale=3.5, + num_inference_steps=50, + generator=generator, + ).images[0] + + if dist.get_rank() == 0: + image.save(f"output.png") + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +``` + +The script above needs to be run with a distributed launcher, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html), that is compatible with PyTorch. `--nproc-per-node` is set to the number of GPUs available. + +```shell +torchrun --nproc-per-node 2 above_script.py ``` ### Ulysses Attention @@ -288,5 +310,83 @@ finally: Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`]. ```py +# Depending on the number of GPUs available. pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2)) -``` \ No newline at end of file +``` + +### Unified Attention + +[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout. 
+
+This hybrid approach leverages the strengths of both methods:
+- **Ulysses Attention** efficiently parallelizes across attention heads
+- **Ring Attention** handles very long sequences with minimal memory overhead
+- Together, they enable 2D parallelization across both heads and sequence dimensions
+
+[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where the Ulysses and Ring groups are orthogonal (non-overlapping).
+Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to a value greater than 1 to [`~ModelMixin.enable_parallelism`].
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
+```
+
+> [!TIP]
+> Use Unified Attention when there are enough devices to arrange in a 2D grid (at least 4 devices).
+
+We ran a benchmark with Ulysses, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
+
+| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
+|------------------|------------------|-------------|------------------|
+| ulysses          | 6670.789         | 7.50        | 33.85            |
+| ring             | 13076.492        | 3.82        | 56.02            |
+| unified_balanced | 11068.705        | 4.52        | 33.85            |
+
+From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by Unified Attention.
+
+### Ulysses Anything Attention
+
+The default Ulysses Attention mechanism requires the sequence length of the hidden states to be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, making Ulysses Attention more versatile in practice.
+
+[`ContextParallelConfig`] supports Ulysses Anything Attention through the `ulysses_degree` and `ulysses_anything` arguments. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with `ulysses_degree` set to a value greater than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
+```
+
+> [!TIP]
+> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group`. This significantly reduces communication latency.
+
+We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs.
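As a side note on the gloo tip above, one way to register both backends when initializing the process group is sketched below. This is a minimal, illustrative snippet, not code from the benchmark script; it assumes the process was launched with torchrun (which sets `LOCAL_RANK`) and relies on PyTorch's composite backend string `"cpu:gloo,cuda:nccl"`.

```py
import os
import torch
from torch import distributed as dist

# Register gloo for CPU tensors alongside nccl for CUDA tensors so that the small
# H2D/D2H transfers used by Ulysses Anything Attention avoid forced CUDA syncs.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)
```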
The results are summarized as follows: + +| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)| +|--------------------|------------------|-------------|------------------|------------| +| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 | +| ring | 351.34 | 2.85 | 37.01 | 1024x1024 | +| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 | +| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 | +| ulysses | failed | failed | failed | 1008x1008 | +| ring | failed | failed | failed | 1008x1008 | +| unified_balanced | failed | failed | failed | 1008x1008 | +| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 | + +From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention. + +### parallel_config + +Pass `parallel_config` during model initialization to enable context parallelism. + +```py +CKPT_ID = "black-forest-labs/FLUX.1-dev" + +cp_config = ContextParallelConfig(ring_degree=2) +transformer = AutoModel.from_pretrained( + CKPT_ID, + subfolder="transformer", + torch_dtype=torch.bfloat16, + parallel_config=cp_config +) + +pipeline = DiffusionPipeline.from_pretrained( + CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16, +).to(device) +``` diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 5aa33190d4a0..05f2b1ee17f3 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -94,7 +94,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 924323753be6..8fba00afc39e 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -88,7 +88,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 3aad6b7b4994..8fb749d328c9 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -95,7 +95,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) @@ -1929,6 +1929,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip): if args.cache_latents: latents_cache = [] + # Store vae config before potential deletion + vae_scaling_factor = vae.config.scaling_factor for batch in tqdm(train_dataloader, desc="Caching latents"): with torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( @@ -1940,6 +1942,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip): del vae if torch.cuda.is_available(): torch.cuda.empty_cache() + else: + vae_scaling_factor = vae.config.scaling_factor # Scheduler and math around the number of training steps. # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. @@ -2109,13 +2113,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = vae.encode(pixel_values).latent_dist.sample() if latents_mean is None and latents_std is None: - model_input = model_input * vae.config.scaling_factor + model_input = model_input * vae_scaling_factor if args.pretrained_vae_model_name_or_path is None: model_input = model_input.to(weight_dtype) else: latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype) latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype) - model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std + model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std model_input = model_input.to(dtype=weight_dtype) # Sample noise that we'll add to the latents diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index b4440e807e49..001934298abe 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) @@ -149,13 +149,13 @@ def get_args(): "--validation_prompt", type=str, default=None, - help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.", ) parser.add_argument( "--validation_images", type=str, default=None, - help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. 
These should correspond to the order of the validation prompts.", ) parser.add_argument( "--validation_prompt_separator", diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 9a1e5fd45c78..f6f2dc83a3f9 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -52,7 +52,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) @@ -140,7 +140,7 @@ def get_args(): "--validation_prompt", type=str, default=None, - help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.", ) parser.add_argument( "--validation_prompt_separator", diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py index ae12012a4c7e..6f06ed749635 100644 --- a/examples/cogview4-control/train_control_cogview4.py +++ b/examples/cogview4-control/train_control_cogview4.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py index 3bdaef79818d..cd4473264e41 100644 --- a/examples/community/marigold_depth_estimation.py +++ b/examples/community/marigold_depth_estimation.py @@ -43,7 +43,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") class MarigoldDepthOutput(BaseOutput): diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py index fb7a4cb5e472..bc6841525b49 100644 --- a/examples/community/pipeline_hunyuandit_differential_img2img.py +++ b/examples/community/pipeline_hunyuandit_differential_img2img.py @@ -21,8 +21,8 @@ BertModel, BertTokenizer, CLIPImageProcessor, - MT5Tokenizer, T5EncoderModel, + T5Tokenizer, ) from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback @@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. 
@@ -295,7 +295,7 @@ def __init__( feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, text_encoder_2=T5EncoderModel, - tokenizer_2=MT5Tokenizer, + tokenizer_2=T5Tokenizer, ): super().__init__() diff --git a/examples/community/pipeline_z_image_differential_img2img.py b/examples/community/pipeline_z_image_differential_img2img.py new file mode 100644 index 000000000000..8bde065c4013 --- /dev/null +++ b/examples/community/pipeline_z_image_differential_img2img.py @@ -0,0 +1,844 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.models.transformers import ZImageTransformer2DModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> init_image = load_image( + >>> "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true", + >>> ) + + >>> mask = load_image( + >>> "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true", + >>> ) + + >>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art" + + >>> image = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask, + ... strength=0.75, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(41), + ... 
).images[0] + >>> image.save("image.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + r""" + The ZImage pipeline for image-to-image generation. + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`PreTrainedModel`]): + A text encoder model to encode text prompts. + tokenizer ([`AutoTokenizer`]): + A tokenizer to tokenize text prompts. + transformer ([`ZImageTransformer2DModel`]): + A ZImage transformer model to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + vae_latent_channels=latent_channels, + do_normalize=False, + do_binarize=False, + do_convert_grayscale=True, + ) + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: 
Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + # Encode the input image + image = image.to(device=device, dtype=dtype) + if image.shape[1] != num_channels_latents: + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + image_latents = image + + # Handle batch size expansion 
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + # Add noise using flow matching scale_noise + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + strength: float = 0.6, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for image-to-image generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. Black pixels in the mask + are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, *optional*, defaults to 0.6): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. 
The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. If not provided, uses the input image height. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. If not provided, uses the input image width. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + # 1. Check inputs and validate strength + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image) + init_image = init_image.to(dtype=torch.float32) + + # Get dimensions from the preprocessed image if not specified + if height is None: + height = init_image.shape[-2] + if width is None: + width = init_image.shape[-1] + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + + # Calculate latent dimensions for image_seq_len + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + + # 6. Adjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(actual_batch_size) + + # 7. Prepare latents from image + latents, noise, original_image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + actual_batch_size, + num_channels_latents, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + latents, + ) + resize_mode = "default" + crops_coords = None + + # start diff diff preparation + original_mask = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + masked_image = init_image * original_mask + original_mask, _ = self.prepare_mask_latents( + original_mask, + masked_image, + batch_size, + num_images_per_prompt, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + ) + mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps + mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device) + masks = original_mask > mask_thresholds + # end diff diff preparation + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 8. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + # start diff diff + image_latent = original_image_latents + latents_dtype = latents.dtype + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + image_latent = self.scheduler.scale_noise( + original_image_latents, torch.tensor([noise_timestep]), noise + ) + + mask = masks[i].to(latents_dtype) + latents = image_latent * mask + latents * (1 - mask) + # end diff diff + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index fb3ad0118360..26a3ecc87935 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -74,7 +74,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index bb35649b51d6..ef50e8eb2da4 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -67,7 +67,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index 99ad07d240b8..a3302d7147b9 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -80,7 +80,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 9f38b8c9b67f..79bc706bcca3 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 3c51dd25c2b2..d6b2dd895766 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -79,7 +79,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 7d85878e6680..198501da725e 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index d1e1c8efd8cb..588f6b1f4ca0 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 6d786f632026..5d54e34eaa06 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -66,7 +66,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 1d6fc57640c3..1d130a38c97e 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -63,7 +63,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index d9e2a712c48f..b853a32c4483 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -62,7 +62,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index c105a3786ea7..5922b7443c10 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -64,7 +64,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0")
+check_min_version("0.37.0.dev0")
 
 logger = get_logger(__name__)
 
diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
index 1d1777811387..ad5d61f1f9e2 100644
--- a/examples/dreambooth/README_flux2.md
+++ b/examples/dreambooth/README_flux2.md
@@ -1,14 +1,22 @@
-# DreamBooth training example for FLUX.2 [dev]
+# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein]
 
 [DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
 
+[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
+
+The `train_dreambooth_lora_flux2.py` and `train_dreambooth_lora_flux2_klein.py` scripts show how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein).
 
-The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
+> [!NOTE]
+> **Model Variants**
+>
+> We support two FLUX model families:
+> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive.
+> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware.
 
 > [!NOTE]
 > **Memory consumption**
 >
-> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
-> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
+> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
+> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. Below we provide some tips and tricks to reduce memory consumption during training.
 
 > For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
 > 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
@@ -17,7 +25,7 @@ The `train_dreambooth_lora_flux2.py` script shows how to implement the training
 > [!NOTE]
 > **Gated model**
 >
-> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
+> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate.
Use the command below to log in: ```bash hf auth login @@ -88,20 +96,32 @@ snapshot_download( This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform. -As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training: +As mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for a more memory efficient training: ## Memory Optimizations > [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption. > However some techniques may be mutually exclusive so be sure to check before launching a training run. + ### Remote Text Encoder -Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API. +FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API. This way, the text encoder model is not loaded into memory during training. + +> [!IMPORTANT] +> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding. + > [!NOTE] > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`. + +### FSDP Text Encoder +FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings. +This way, it distributes the memory cost across multiple nodes. + ### CPU Offloading To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed. + ### Latent Caching Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`. + ### QLoRA: Low Precision Training with Quantization Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags: - **FP8 training** with `torchao`: @@ -111,22 +131,29 @@ enable FP8 training by passing `--do_fp8_training`. - **NF4 training** with `bitsandbytes`: Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing: `--bnb_quantization_config_path` to enable 4-bit NF4 quantization. + ### Gradient Checkpointing and Accumulation * `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass. by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs. * with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass. Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass. 
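For orientation, the two flags described above map onto standard Accelerate and diffusers calls. The snippet below is a minimal, illustrative sketch with a toy model and a synthetic loop, not the training script itself; a diffusers transformer would additionally call `enable_gradient_checkpointing()`.

```py
import torch
from torch import nn
from accelerate import Accelerator

# Toy stand-in for the transformer; on a diffusers model, --gradient_checkpointing
# toggles model.enable_gradient_checkpointing() to recompute activations in backward.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# --gradient_accumulation_steps=4 corresponds to accumulating over 4 micro-batches.
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(8):
    batch = torch.randn(2, 16)
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()       # a real parameter update only happens every 4th micro-batch
        optimizer.zero_grad()
```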
+ ### 8-bit-Adam Optimizer When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so. + ### Image Resolution An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this. Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions. + ### Precision of saved LoRA layers By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well. This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`. +## Training Examples +### FLUX.2 [dev] Training +To perform DreamBooth with LoRA on FLUX.2 [dev], run: ```bash export MODEL_NAME="black-forest-labs/FLUX.2-dev" export INSTANCE_DIR="dog" @@ -158,19 +185,104 @@ accelerate launch train_dreambooth_lora_flux2.py \ --push_to_hub ``` -To better track our training experiments, we're using the following flags in the command above: +### FLUX 2 [klein] Training -* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. -* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. +FLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1. > [!NOTE] -> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases. +> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. The Qwen VL text encoder must be loaded locally, but offloading is still supported. -## LoRA + DreamBooth +**FLUX 2 [klein] 4B:** -[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters. +```bash +export MODEL_NAME="black-forest-labs/FLUX.2-klein-4B" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-flux2-klein-4b" -Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. 
+accelerate launch train_dreambooth_lora_flux2_klein.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --do_fp8_training \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=1 \ + --use_8bit_adam \ + --gradient_accumulation_steps=4 \ + --optimizer="adamW" \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +**FLUX 2 [klein] 9B:** + +```bash +export MODEL_NAME="black-forest-labs/FLUX.2-klein-9B" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="trained-flux2-klein-9b" + +accelerate launch train_dreambooth_lora_flux2_klein.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --do_fp8_training \ + --gradient_checkpointing \ + --cache_latents \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=1 \ + --use_8bit_adam \ + --gradient_accumulation_steps=4 \ + --optimizer="adamW" \ + --learning_rate=1e-4 \ + --report_to="wandb" \ + --lr_scheduler="constant" \ + --lr_warmup_steps=100 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed="0" \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login ` before training if you haven't done it before. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +> [!NOTE] +> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases. + +### FSDP on the transformer +By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to: + +```shell +distributed_type: FSDP +fsdp_config: + fsdp_version: 2 + fsdp_offload_params: false + fsdp_sharding_strategy: HYBRID_SHARD + fsdp_auto_wrap_policy: TRANSFOMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock + fsdp_forward_prefetch: true + fsdp_sync_module_states: false + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_use_orig_params: false + fsdp_activation_checkpointing: true + fsdp_reshard_after_forward: true + fsdp_cpu_ram_efficient_loading: false +``` ### Prodigy Optimizer Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence. @@ -183,8 +295,6 @@ to use prodigy, first make sure to install the prodigyopt library: `pip install > [!TIP] > When using prodigy it's generally good practice to set- `--learning_rate=1.0` -To perform DreamBooth with LoRA, run: - ```bash export MODEL_NAME="black-forest-labs/FLUX.2-dev" export INSTANCE_DIR="dog" @@ -248,13 +358,10 @@ the exact modules for LoRA training. 
Here are some examples of target modules yo > keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights. - ## Training Image-to-Image Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too. -**important** - **Important** To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment: @@ -311,5 +418,6 @@ we've added aspect ratio bucketing support which allows training on images with To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as: `--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672" -` -Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗 + + +Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗 \ No newline at end of file diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md index 7972434b5e6f..8bddacf975d8 100644 --- a/examples/dreambooth/README_sana.md +++ b/examples/dreambooth/README_sana.md @@ -111,6 +111,25 @@ To better track our training experiments, we're using the following flags in the ## Notes +### LoRA Rank and Alpha +Two key LoRA hyperparameters are LoRA rank and LoRA alpha. +- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters). +- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank. +- lora_alpha vs. rank: +This ratio dictates the LoRA's effective strength: +lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16) +lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16) +lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16) + +> [!TIP] +> A common starting point is to set `lora_alpha` equal to `rank`. +> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16) +> to give the LoRA updates more influence without increasing parameter count. +> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank` +> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case. + +### Additional CLI arguments + Additionally, we welcome you to explore the following CLI arguments: * `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only. 
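To make the `lora_alpha` / `rank` relationship above concrete, here is a minimal sketch using the PEFT backend these scripts already rely on; the layer names are just the attention-projection examples from the `--lora_layers` description above.

```py
# Minimal sketch of the rank/alpha scaling described above, using the PEFT backend of these scripts.
# The effective LoRA update is scaled by lora_alpha / r.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,           # rank of the LoRA update matrices
    lora_alpha=16,  # alpha == rank -> scaling factor of 1.0
    target_modules=["to_k", "to_q", "to_v"],  # what a `--lora_layers` value like "to_k,to_q,to_v" maps to
)
print(lora_config.lora_alpha / lora_config.r)  # 1.0 here; 0.5 with alpha=8, 2.0 with alpha=32
```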
diff --git a/examples/dreambooth/test_dreambooth_lora_flux2_klein.py b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py new file mode 100644 index 000000000000..0e5506e1a3eb --- /dev/null +++ b/examples/dreambooth/test_dreambooth_lora_flux2_klein.py @@ -0,0 +1,262 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import sys +import tempfile + +import safetensors + +from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "dog" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein" + script_path = "examples/dreambooth/train_dreambooth_lora_flux2_klein.py" + transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj" + + def test_dreambooth_lora_flux2(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. 
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. 
In this test, we only params of + # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict + starts_with_transformer = all( + key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --max_sequence_length 8 + --checkpointing_steps=2 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + --max_sequence_length 8 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + --max_sequence_length 8 + --text_encoder_out_layers 1 + """.split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) + + def test_dreambooth_lora_with_metadata(self): + # Use a `lora_alpha` that is different from `rank`. + lora_alpha = 8 + rank = 4 + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_alpha={lora_alpha} + --rank={rank} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --max_sequence_length 8 + --text_encoder_out_layers 1 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") + self.assertTrue(os.path.isfile(state_dict_file)) + + # Check if the metadata was properly serialized. 
+ with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f: + metadata = f.metadata() or {} + + metadata.pop("format", None) + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + if raw: + raw = json.loads(raw) + + loaded_lora_alpha = raw["transformer.lora_alpha"] + self.assertTrue(loaded_lora_alpha == lora_alpha) + loaded_lora_rank = raw["transformer.r"] + self.assertTrue(loaded_lora_rank == rank) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 503e2ae1d47d..2e66e1f724e7 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -64,7 +64,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 6c09f0a84cf9..e68d9df5e424 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -35,7 +35,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index c24d16c6005a..468f6fce3ecb 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -80,7 +80,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index b105aa55361a..2d15684f9107 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -75,7 +75,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 3b6ab814f278..8ae2ddd9796b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -92,7 +92,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 733abe16d2eb..317ed2c2b2e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -44,6 +44,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from typing import Any import numpy as np import torch @@ -75,13 +76,16 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import ( _collate_lora_metadata, + _to_cpu_contiguous, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, find_nearest_bucket, free_memory, + get_fsdp_kwargs_from_accelerator, offload_models, parse_buckets_string, + wrap_with_fsdp, ) from diffusers.utils import ( check_min_version, @@ -93,11 +97,14 @@ from diffusers.utils.torch_utils import is_compiled_module +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) @@ -722,6 +729,7 @@ def parse_args(input_args=None): ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") if input_args is not None: args = parser.parse_args(input_args) @@ -1219,7 +1227,11 @@ def main(args): if args.bnb_quantization_config_path is not None else {"device": accelerator.device, "dtype": weight_dtype} ) - transformer.to(**transformer_to_kwargs) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: convert_to_float8_training( transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) @@ -1263,17 +1275,42 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + peft_kwargs = {} + if is_fsdp: + 
peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) - # make sure to pop weight so that corresponding model is not saved again + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: weights.pop() Flux2Pipeline.save_lora_weights( @@ -1285,13 +1322,20 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1507,6 +1551,21 @@ def _encode_single(prompt: str): args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. 
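For readers less familiar with FSDP, the following is a conceptual sketch of what the `--fsdp_text_encoder` path above amounts to, written against the standard `torch.distributed.fsdp` API rather than the `wrap_with_fsdp` helper's actual internals: the text encoder is sharded across ranks, optionally offloaded to CPU, and every rank synchronizes before prompt encoding starts.

```py
# Conceptual sketch only: sharding the text encoder with FSDP, using the standard
# torch.distributed.fsdp API. The `wrap_with_fsdp` helper used by the script may
# configure things differently internally.
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

def shard_text_encoder(text_encoder, device, offload: bool = False):
    sharded = FSDP(
        text_encoder,
        device_id=device,                                                   # place each rank's shard on its GPU
        cpu_offload=CPUOffload(offload_params=True) if offload else None,  # optionally park params in CPU RAM
        limit_all_gathers=True,
        use_orig_params=True,
    )
    dist.barrier()  # make sure every rank finished wrapping before prompts are encoded
    return sharded
```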
@@ -1536,6 +1595,8 @@ def _encode_single(prompt: str): if train_dataset.custom_instance_prompts: if args.remote_text_encoder: prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + elif args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) @@ -1777,7 +1838,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or is_fsdp: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1836,15 +1897,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) if accelerator.is_main_process: modules_to_save = {} - transformer = unwrap_model(transformer) - if args.bnb_quantization_config_path is None: - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer Flux2Pipeline.save_lora_weights( diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 32bce9531b71..16a3863c881d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -43,6 +43,7 @@ import shutil from contextlib import nullcontext from pathlib import Path +from typing import Any import numpy as np import torch @@ -74,13 +75,16 @@ from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor from diffusers.training_utils import ( _collate_lora_metadata, + _to_cpu_contiguous, cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, find_nearest_bucket, free_memory, + get_fsdp_kwargs_from_accelerator, offload_models, parse_buckets_string, + wrap_with_fsdp, ) from diffusers.utils import ( check_min_version, @@ -93,11 +97,14 @@ from diffusers.utils.torch_utils import is_compiled_module +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + if is_wandb_available(): import wandb # Will error if the 
minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) @@ -120,7 +127,7 @@ def save_model_card( ) model_description = f""" -# Flux DreamBooth LoRA - {repo_id} +# Flux.2 DreamBooth LoRA - {repo_id} @@ -339,7 +346,7 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, + required=False, help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( @@ -691,6 +698,7 @@ def parse_args(input_args=None): parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") if input_args is not None: args = parser.parse_args(input_args) @@ -827,15 +835,28 @@ def __init__( dest_image = self.cond_images[i] image_width, image_height = dest_image.size if image_width * image_height > 1024 * 1024: - dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024) + dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024) image_width, image_height = dest_image.size multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp! image_width = (image_width // multiple_of) * multiple_of image_height = (image_height // multiple_of) * multiple_of - dest_image = Flux2ImageProcessor.image_processor.preprocess( + image_processor = Flux2ImageProcessor() + dest_image = image_processor.preprocess( dest_image, height=image_height, width=image_width, resize_mode="crop" ) + # Convert back to PIL + dest_image = dest_image.squeeze(0) + if dest_image.min() < 0: + dest_image = (dest_image + 1) / 2 + dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu() + + if dest_image.shape[0] == 1: + # Gray scale image + dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L") + else: + # RGB scale image: (C, H, W) -> (H, W, C) + dest_image = TF.to_pil_image(dest_image) dest_image = exif_transpose(dest_image) if not dest_image.mode == "RGB": @@ -1156,7 +1177,11 @@ def main(args): if args.bnb_quantization_config_path is not None else {"device": accelerator.device, "dtype": weight_dtype} ) - transformer.to(**transformer_to_kwargs) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + if args.do_fp8_training: convert_to_float8_training( transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) @@ -1200,17 +1225,42 @@ def unwrap_model(model): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) 
if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None if accelerator.is_main_process: - transformer_lora_layers_to_save = None - modules_to_save = {} - for model in models: - if isinstance(model, type(unwrap_model(transformer))): - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - modules_to_save["transformer"] = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) - # make sure to pop weight so that corresponding model is not saved again + # make sure to pop weight so that corresponding model is not saved again + if weights: weights.pop() Flux2Pipeline.save_lora_weights( @@ -1222,13 +1272,20 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): transformer_ = None - while len(models) > 0: - model = models.pop() + if not is_fsdp: + while len(models) > 0: + model = models.pop() - if isinstance(model, type(unwrap_model(transformer))): - transformer_ = model - else: - raise ValueError(f"unexpected save model: {model.__class__}") + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir) @@ -1419,9 +1476,9 @@ def _encode_single(prompt: str): args.instance_prompt, text_encoding_pipeline ) - validation_image = load_image(args.validation_image_path).convert("RGB") - validation_kwargs = {"image": validation_image} if args.validation_prompt is not None: + validation_image = load_image(args.validation_image_path).convert("RGB") + validation_kwargs = {"image": validation_image} if args.remote_text_encoder: validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt) else: @@ -1430,6 +1487,21 @@ def _encode_single(prompt: str): args.validation_prompt, text_encoding_pipeline ) + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. 
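The FSDP branch of the save hook above (in both the text-to-image and image-to-image scripts) follows the same pattern; distilled into a standalone sketch, assuming the standard Accelerate and PEFT APIs, it looks roughly like this:

```py
# Sketch of the FSDP save pattern used above: gather the sharded weights once,
# then let only the main process extract the LoRA parameters and move them to CPU.
from peft.utils import get_peft_model_state_dict

def gather_lora_for_saving(accelerator, transformer):
    state_dict = accelerator.get_state_dict(transformer)  # consolidates shards across ranks
    if not accelerator.is_main_process:
        return None
    lora_layers = get_peft_model_state_dict(accelerator.unwrap_model(transformer), state_dict=state_dict)
    # contiguous CPU tensors are required for safetensors serialization
    return {k: v.detach().cpu().contiguous() for k, v in lora_layers.items()}
```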
@@ -1461,6 +1533,8 @@ def _encode_single(prompt: str): if train_dataset.custom_instance_prompts: if args.remote_text_encoder: prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"]) + elif args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) else: with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) @@ -1621,9 +1695,13 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device) - cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to( + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] + cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to( device=cond_model_input.device ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], -1, model_input_ids.shape[-1] + ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) @@ -1650,6 +1728,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input) packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input) + orig_input_shape = packed_noisy_model_input.shape + orig_input_ids_shape = model_input_ids.shape + # concatenate the model inputs with the cond inputs packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) @@ -1668,7 +1749,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): img_ids=model_input_ids, # B, image_seq_len, 4 return_dict=False, )[0] - model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + model_pred = model_pred[:, : orig_input_shape[1], :] + model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :] model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids) @@ -1700,7 +1782,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or is_fsdp: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1759,15 +1841,41 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Save the lora layers accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) if accelerator.is_main_process: modules_to_save = {} - transformer = unwrap_model(transformer) - if args.bnb_quantization_config_path is None: - if args.upcast_before_saving: - transformer.to(torch.float32) - else: - transformer = transformer.to(weight_dtype) - transformer_lora_layers = get_peft_model_state_dict(transformer) + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + 
transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + modules_to_save["transformer"] = transformer Flux2Pipeline.save_lora_weights( diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py new file mode 100644 index 000000000000..278c25900a3a --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -0,0 +1,1942 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _collate_lora_metadata, + _to_cpu_contiguous, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + get_fsdp_kwargs_from_accelerator, + offload_models, + parse_buckets_string, + wrap_with_fsdp, +) +from diffusers.utils import ( + check_min_version, + 
convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.37.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + quant_training=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux.2 [Klein] DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md). + +Quant training? {quant_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux2-klein", + "flux2-klein-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. 
By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--text_encoder_out_layers", + type=int, + nargs="+", + default=[10, 20, 30], + help="Text encoder hidden layers to compute the final text embeddings.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." 
+ ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. 
Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
" + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + if args.do_fp8_training and args.bnb_quantization_config_path: + raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. 
+ # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image = self.train_transform( + image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, 
bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + + return image + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. 
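+    # When prior preservation is enabled, the batch is laid out as
+    # [instance examples..., class examples...] along dim 0, so a single forward
+    # pass covers both; the training loop later splits the prediction back apart
+    # with torch.chunk(model_pred, 2, dim=0).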
+ if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. 
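+    # (Autocast support on the MPS backend is limited in PyTorch, which is
+    # presumably why Accelerate's native AMP is turned off here.)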
+ if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = Qwen2TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
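+    # The trainable LoRA parameters themselves are cast back to float32 further
+    # below via `cast_training_params` when fp16 mixed precision is used.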
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + accelerator.device + ) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + text_encoder = Qwen3ForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. 
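+    # Passing `vae=None`, `transformer=None` and `scheduler=None` wires only the
+    # tokenizer and text encoder into this pipeline; it is moved to the accelerator
+    # device on demand (see `offload_models`) and used purely for prompt encoding.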
+ text_encoding_pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None + if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + Flux2KleinPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not is_fsdp: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading 
adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( + prompt=prompt, + max_sequence_length=args.max_sequence_length, + text_encoder_out_layers=args.text_encoder_out_layers, + ) + return prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + class_prompt_hidden_states, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoding_pipeline + ) + validation_embeddings = {} + if args.validation_prompt is not None: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + (validation_embeddings["prompt_embeds"], validation_embeddings["text_ids"]) = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. 
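+    # When prior preservation is used, the class-prompt embeddings are concatenated
+    # after the instance-prompt embeddings so they match the [instance, class]
+    # ordering produced by `collate_fn`.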
+ if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + if precompute_latents: + prompt_embeds_cache = [] + text_ids_cache = [] + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + if args.fsdp_text_encoder: + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline) + prompt_embeds_cache.append(prompt_embeds) + text_ids_cache.append(text_ids) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
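+    # For example, with 100 batches per process and gradient_accumulation_steps=4,
+    # each epoch performs ceil(100 / 4) = 25 optimizer updates, so 10 epochs imply
+    # max_train_steps = 250 when it is not set explicitly.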
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux2-klein-lora" + args_cp = vars(args).copy() + args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"]) + accelerator.init_trackers(tracker_name, config=args_cp) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + text_ids = text_ids_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + + model_input = Flux2KleinPipeline._patchify_latents(model_input) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # [B, C, H, W] -> [B, H*W, C] + packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.full([1], args.guidance_scale, device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, # (B, image_seq_len, C) + timestep=timesteps / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=model_input_ids, # B, image_seq_len, 4 + return_dict=False, + )[0] + model_pred = model_pred[:, : packed_noisy_model_input.size(1) :] + + model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. 
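+                    # total loss = L_instance + prior_loss_weight * L_prior
+                    # (the DreamBooth prior-preservation objective)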
+ loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or is_fsdp: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) + if accelerator.is_main_process: + modules_to_save = {} + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + 
modules_to_save["transformer"] = transformer + + Flux2KleinPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_embeddings, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + images = None + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + quant_training = None + if args.do_fp8_training: + quant_training = "FP8 TorchAO" + elif args.bnb_quantization_config_path: + quant_training = "BitsandBytes" + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + quant_training=quant_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py new file mode 100644 index 000000000000..28cbaf8f72e7 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -0,0 +1,1889 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import copy +import itertools +import json +import logging +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor +from diffusers.training_utils import ( + _collate_lora_metadata, + _to_cpu_contiguous, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + get_fsdp_kwargs_from_accelerator, + offload_models, + parse_buckets_string, + wrap_with_fsdp, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, + load_image, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.37.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + fp8_training=False, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Flux.2 [Klein] DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux2 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux2.md). + +FP8 training? {fp8_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. 
+ +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.2", torch_dtype=torch.bfloat16).to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.2/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux2", + "flux2-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(dtype=torch_dtype) + pipeline.enable_model_cpu_offload() + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the output module + if fqn == "proj_out": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to 
pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--bnb_quantization_config_path",
+        type=str,
+        default=None,
+        help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
+    )
+    parser.add_argument(
+        "--do_fp8_training",
+        action="store_true",
+        help="Whether to perform FP8 training.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that 🤗 Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--instance_data_dir",
+        type=str,
+        default=None,
+        help=("A folder containing the training data. "),
+    )
+
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+
+    parser.add_argument(
+        "--image_column",
+        type=str,
+        default="image",
+        help="The column of the dataset containing the target image. By "
+        "default, the standard Image Dataset maps out 'file_name' "
+        "to 'image'.",
+    )
+    parser.add_argument(
+        "--cond_image_column",
+        type=str,
+        default=None,
+        help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+    )
+    parser.add_argument(
+        "--caption_column",
+        type=str,
+        default=None,
+        help="The column of the dataset containing the instance prompt for each image",
+    )
+
+    parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+    parser.add_argument(
+        "--class_data_dir",
+        type=str,
+        default=None,
+        required=False,
+        help="A folder containing the training data of class images.",
+    )
+    parser.add_argument(
+        "--instance_prompt",
+        type=str,
+        default=None,
+        required=False,
+        help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+    )
+    parser.add_argument(
+        "--max_sequence_length",
+        type=int,
+        default=512,
+        help="Maximum sequence length to use with the text encoder",
+    )
+    parser.add_argument(
+        "--validation_prompt",
+        type=str,
+        default=None,
+        help="A prompt that is used during validation to verify that the model is learning.",
+    )
+    parser.add_argument(
+        "--validation_image",
+        type=str,
+        default=None,
+        help="Path to an image that is used during validation as the condition image to verify that the model is learning.",
+    )
+    parser.add_argument(
+        "--skip_final_inference",
+        default=False,
+        action="store_true",
+        help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. 
Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default=3.5,
+        help="Guidance scale to embed during training; only used when the transformer is guidance-distilled (i.e. `config.guidance_embeds` is True).",
+    )
+
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument(
+        "--weighting_scheme",
+        type=str,
+        default="none",
+        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+        help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+    )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+    )
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        default="AdamW",
+        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+    )
+
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+    )
+
+    parser.add_argument(
+        "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+    )
+    parser.add_argument(
+        "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+    )
+    parser.add_argument(
+        "--prodigy_beta3",
+        type=float,
+        default=None,
+        help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+        "uses the value of square root of beta2. 
Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
" + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.cond_image_column is None: + raise ValueError( + "you must provide --cond_image_column for image-to-image training. Otherwise please see Flux2 text-to-image training example." + ) + else: + assert args.image_column is not None + assert args.caption_column is not None + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.cond_image_column is not None and args.cond_image_column not in column_names: + raise ValueError( + f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + cond_images = None + cond_image_column = args.cond_image_column + if cond_image_column is not None: + cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))] + assert len(instance_images) == len(cond_images) + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + self.cond_images = [] + for i, img in enumerate(instance_images): + self.instance_images.extend(itertools.repeat(img, repeats)) + if args.dataset_name is not None and cond_images is not None: + self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) + + self.pixel_values = [] + self.cond_pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + dest_image = None + if self.cond_images: # todo: take care of max area for buckets + dest_image = self.cond_images[i] + image_width, image_height = dest_image.size + if image_width * image_height > 1024 * 1024: + dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024) + image_width, image_height = dest_image.size + + multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp! 
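+                    # Note: flooring to multiple_of keeps both sides divisible by the VAE's spatial
+                    # downsampling factor; e.g. a 1000x700 condition image becomes 1000x696.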
+ image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + image_processor = Flux2ImageProcessor() + dest_image = image_processor.preprocess( + dest_image, height=image_height, width=image_width, resize_mode="crop" + ) + # Convert back to PIL + dest_image = dest_image.squeeze(0) + if dest_image.min() < 0: + dest_image = (dest_image + 1) / 2 + dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu() + + if dest_image.shape[0] == 1: + # Gray scale image + dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L") + else: + # RGB scale image: (C, H, W) -> (H, W, C) + dest_image = TF.to_pil_image(dest_image) + + dest_image = exif_transpose(dest_image) + if not dest_image.mode == "RGB": + dest_image = dest_image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image, dest_image = self.paired_transform( + image, + dest_image=dest_image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + if dest_image is not None: + self.cond_pixel_values.append((dest_image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.cond_pixel_values: + dest_image, _ = self.cond_pixel_values[index % self.num_instance_images] + example["cond_images"] = dest_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + return example + + def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + if dest_image is not None: + dest_image = resize(dest_image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + if dest_image is not None: + dest_image = crop(dest_image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + if dest_image is not None: + dest_image = TF.crop(dest_image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + if dest_image is not None: + dest_image = TF.hflip(dest_image) + + # 4. 
ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + if dest_image is not None: + dest_image = normalize(to_tensor(dest_image)) + + return (image, dest_image) if dest_image is not None else (image, None) + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + if any("cond_images" in example for example in examples): + cond_pixel_values = [example["cond_images"] for example in examples] + cond_pixel_values = torch.stack(cond_pixel_values) + cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() + batch.update({"cond_pixel_values": cond_pixel_values}) + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer = Qwen2TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
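+    # For example, launching with `accelerate launch --mixed_precision=bf16` makes
+    # accelerator.mixed_precision == "bf16", which maps to torch.bfloat16 below.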
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + accelerator.device + ) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + text_encoder = Qwen3ForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device + transformer_to_kwargs = ( + {"device": accelerator.device} + if args.bnb_quantization_config_path is not None + else {"device": accelerator.device, "dtype": weight_dtype} + ) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + text_encoder.to(**to_kwargs) + # Initialize a text encoding pipeline and keep it to CPU for now. 
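+    # Note: the pipeline below is constructed with vae=None and transformer=None, so it only holds the
+    # tokenizer and text encoder and is used solely for `encode_prompt` when computing prompt embeddings.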
+ text_encoding_pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=None, + revision=args.revision, + ) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None + if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + Flux2KleinPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not is_fsdp: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Flux2Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = Flux2KleinPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading 
adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, text_encoding_pipeline): + with torch.no_grad(): + prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt( + prompt=prompt, max_sequence_length=args.max_sequence_length + ) + return prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_prompt_hidden_states, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoding_pipeline + ) + + if args.validation_prompt is not None: + validation_image = load_image(args.validation_image).convert("RGB") + validation_kwargs = {"image": validation_image} + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + validation_kwargs["prompt_embeds"], _text_ids = compute_text_embeddings( + args.validation_prompt, text_encoding_pipeline + ) + validation_kwargs["negative_prompt_embeds"], _text_ids = compute_text_embeddings( + "", text_encoding_pipeline + ) + + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + prompt_embeds = instance_prompt_hidden_states + text_ids = instance_text_ids + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. 
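+    # Note: pre-computing these caches lets us move the VAE and text encoder off the accelerator and delete
+    # them before the training loop, at the cost of keeping the cached latents/prompt embeddings in memory.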
+
+    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    if precompute_latents:
+        prompt_embeds_cache = []
+        text_ids_cache = []
+        latents_cache = []
+        cond_latents_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                if args.cache_latents:
+                    with offload_models(vae, device=accelerator.device, offload=args.offload):
+                        batch["pixel_values"] = batch["pixel_values"].to(
+                            accelerator.device, non_blocking=True, dtype=vae.dtype
+                        )
+                        latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                        batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+                            accelerator.device, non_blocking=True, dtype=vae.dtype
+                        )
+                        cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
+                if train_dataset.custom_instance_prompts:
+                    if args.fsdp_text_encoder:
+                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+                    else:
+                        with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+                            prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+                    prompt_embeds_cache.append(prompt_embeds)
+                    text_ids_cache.append(text_ids)
+
+    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+    if args.cache_latents:
+        vae = vae.to("cpu")
+        del vae
+
+    # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
+    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+    del text_encoder, tokenizer
+    free_memory()
+
+    # Scheduler and math around the number of training steps.
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+    if args.max_train_steps is None:
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
+        num_cycles=args.lr_num_cycles,
+        power=args.lr_power,
+    )
+
+    # Prepare everything with our `accelerator`.
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+        logger.warning(
+            f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+            f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+            f"This inconsistency may result in the learning rate scheduler not functioning properly."
+        )
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_name = "dreambooth-flux2-image2img-lora"
+        accelerator.init_trackers(tracker_name, config=vars(args))
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+            first_epoch = global_step // num_update_steps_per_epoch
+
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine. 
+ disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds = prompt_embeds_cache[step] + text_ids = text_ids_cache[step] + else: + num_repeat_elements = len(prompts) + prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) + text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + cond_model_input = cond_latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) + + model_input = vae.encode(pixel_values).latent_dist.mode() + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() + + model_input = Flux2KleinPipeline._patchify_latents(model_input) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + cond_model_input = Flux2KleinPipeline._patchify_latents(cond_model_input) + cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std + + model_input_ids = Flux2KleinPipeline._prepare_latent_ids(model_input).to(device=model_input.device) + cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])] + cond_model_input_ids = Flux2KleinPipeline._prepare_image_ids(cond_model_input_list).to( + device=cond_model_input.device + ) + cond_model_input_ids = cond_model_input_ids.view( + cond_model_input.shape[0], -1, model_input_ids.shape[-1] + ) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # [B, C, H, W] -> [B, H*W, C] + # concatenate the model inputs with the cond inputs + packed_noisy_model_input = Flux2KleinPipeline._pack_latents(noisy_model_input) + packed_cond_model_input = Flux2KleinPipeline._pack_latents(cond_model_input) + orig_input_shape = packed_noisy_model_input.shape + orig_input_ids_shape = model_input_ids.shape + + # concatenate the model inputs with the cond inputs + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1) + model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.full([1], args.guidance_scale, device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, # (B, image_seq_len, C) + timestep=timesteps / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=model_input_ids, # B, image_seq_len, 4 + return_dict=False, + )[0] + # pruning the condition information + model_pred = model_pred[:, : orig_input_shape[1], :] + model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :] + + model_pred = Flux2KleinPipeline._unpack_latents_with_ids(model_pred, model_input_ids) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + # Compute regular loss. 
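+                # Note: `target = noise - model_input` is the flow-matching velocity for the linear
+                # interpolation path; with the default weighting_scheme="none" the weighting is all ones,
+                # so this reduces to a plain MSE between the predicted and target velocities.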
+ loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or is_fsdp: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=None, + tokenizer=None, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_kwargs, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) + if accelerator.is_main_process: + modules_to_save = {} + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + 
transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + modules_to_save["transformer"] = transformer + + Flux2KleinPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Flux2KleinPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=validation_kwargs, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + fp8_training=args.do_fp8_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index fc6df87768ca..1a6757810a80 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -92,7 +92,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 8cbc3a43fde3..3abc7afcad2c 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -75,7 +75,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py index 8bf489586356..a13c579718c7 100644 --- a/examples/dreambooth/train_dreambooth_lora_lumina2.py +++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 56de160d6f29..33a1054effaf 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -93,7 +93,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) @@ -1467,7 +1467,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): else: num_repeat_elements = len(prompts) prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) - prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1) # Convert images to latent space if args.cache_latents: model_input = latents_cache[step].sample() @@ -1513,14 +1514,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): height=model_input.shape[3], width=model_input.shape[4], ) - print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}") model_pred = transformer( hidden_states=packed_noisy_model_input, encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, timestep=timesteps / 1000, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, )[0] model_pred = QwenImagePipeline._unpack_latents( diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 2b0c1ee6697d..0afc31cf8a9a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -91,7 +91,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index eef732c531d3..d6770c805d25 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 1ffb73cee4a2..51bac5d59667 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -80,7 +80,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index d345ebb391e3..e43e3178202a 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -64,7 +64,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index fe47e074412d..1e3be74464be 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -55,7 +55,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 36320449bd50..3185f1b2ea6a 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -58,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 85b85aa2fabe..1bfe7aed30cb 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -58,7 +58,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index acf5d8dff054..9a5b23a8e623 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index a30e2559537e..158c3a6f0994 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -53,7 +53,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 57c92f3ae543..30094f54827f 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index 2a0ef7d6fbba..9c0a4c38504e 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index df7cffef9b2f..caa8d96ef3ec 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/research_projects/control_lora/README.md b/examples/research_projects/control_lora/README.md new file mode 100644 index 000000000000..49aa848e3e0b --- /dev/null +++ b/examples/research_projects/control_lora/README.md @@ -0,0 +1,41 @@ +# Control-LoRA inference example + +Control-LoRA is introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) by adding low-rank parameter efficient fine tuning to ControlNet. This approach offers a more efficient and compact method to bring model control to a wider variety of consumer GPUs. + +## Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +## Inference on SDXL + +[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image. 
+ +```bash +python control_lora.py +``` + +## Acknowledgements + +- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) +- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors) +- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2) \ No newline at end of file diff --git a/examples/research_projects/control_lora/control_lora.py b/examples/research_projects/control_lora/control_lora.py new file mode 100644 index 000000000000..a0ad1981c71d --- /dev/null +++ b/examples/research_projects/control_lora/control_lora.py @@ -0,0 +1,58 @@ +import cv2 +import numpy as np +import torch +from PIL import Image + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, +) +from diffusers.utils import load_image, make_image_grid + + +pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" +lora_id = "stabilityai/control-lora" +lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" + +unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.bfloat16).to("cuda") +controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16) +controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config) + +prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" +negative_prompt = "low quality, bad quality, sketches" + +image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" +) + +controlnet_conditioning_scale = 1.0 # recommended for good generalization + +vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16) +pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + pipe_id, + unet=unet, + controlnet=controlnet, + vae=vae, + torch_dtype=torch.bfloat16, + safety_checker=None, +).to("cuda") + +image = np.array(image) +image = cv2.Canny(image, 100, 200) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +image = Image.fromarray(image) + +images = pipe( + prompt, + negative_prompt=negative_prompt, + image=image, + controlnet_conditioning_scale=controlnet_conditioning_scale, + num_images_per_prompt=4, +).images + +final_image = [image] + images +grid = make_image_grid(final_image, 1, 5) +grid.save("hf-logo_canny.png") diff --git a/examples/research_projects/lpl/README.md b/examples/research_projects/lpl/README.md new file mode 100644 index 000000000000..a69fead50893 --- /dev/null +++ b/examples/research_projects/lpl/README.md @@ -0,0 +1,157 @@ +# Latent Perceptual Loss (LPL) for Stable Diffusion XL + +This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada. + +## Overview + +LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. 
While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to: + +- Loss of fine details in generated images +- Inconsistent image quality +- Structural artifacts +- Reduced sharpness and realism + +LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to: + +- Improved image quality and consistency (6-20% FID improvement) +- Better preservation of fine details +- More stable training, especially at high noise levels +- Better handling of structural information +- Sharper and more realistic textures + +## Implementation Details + +The LPL implementation follows the paper's methodology and includes several key features: + +1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including: + - Middle block features + - Up block features (configurable number of blocks) + - Proper gradient checkpointing for memory efficiency + - Features are extracted only for timesteps below the threshold (high SNR) + +2. **Feature Normalization**: Multiple normalization options as validated in the paper: + - `default`: Normalize each feature map independently + - `shared`: Cross-normalize features using target statistics (recommended) + - `batch`: Batch-wise normalization + +3. **Outlier Handling**: Optional removal of outliers in feature maps using: + - Quantile-based filtering (2% quantiles) + - Morphological operations (opening/closing) + - Adaptive thresholding based on standard deviation + +4. **Loss Types**: + - MSE loss (default) + - L1 loss + - Optional power law weighting (2^(-i) for layer i) + +## Usage + +To use LPL in your training, add the following arguments to your training command: + +```bash +python examples/research_projects/lpl/train_sdxl_lpl.py \ + --use_lpl \ + --lpl_weight 1.0 \ # Weight for LPL loss (1.0-2.0 recommended) + --lpl_t_threshold 200 \ # Apply LPL only for timesteps < threshold (high SNR) + --lpl_loss_type mse \ # Loss type: "mse" or "l1" + --lpl_norm_type shared \ # Normalization type: "default", "shared" (recommended), or "batch" + --lpl_pow_law \ # Use power law weighting for layers + --lpl_num_blocks 4 \ # Number of up blocks to use (1-4) + --lpl_remove_outliers \ # Remove outliers in feature maps + --lpl_scale \ # Scale LPL loss by noise level weights + --lpl_start 0 \ # Step to start applying LPL + # ... other training arguments ... +``` + +### Key Parameters + +- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training. +- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps. +- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases. +- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed best results in the paper. +- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance. +- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory. +- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training. +- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps. 
+- `lpl_start`: Training step to start applying LPL. Can be used to warm up training. + +## Recommendations + +1. **Starting Point** (based on paper results): + ```bash + --use_lpl \ + --lpl_weight 1.0 \ + --lpl_t_threshold 200 \ + --lpl_loss_type mse \ + --lpl_norm_type shared \ + --lpl_pow_law \ + --lpl_num_blocks 4 \ + --lpl_remove_outliers \ + --lpl_scale + ``` + +2. **Memory Efficiency**: + - Use `--gradient_checkpointing` for memory efficiency (enabled by default) + - Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results) + - Consider using `--lpl_scale` to focus on more important timesteps + - Features are extracted only for timesteps below threshold to save memory + +3. **Quality vs Speed**: + - Higher `lpl_weight` (1.0-2.0) for better quality + - Lower `lpl_t_threshold` (100-200) for faster training + - Use `lpl_remove_outliers` for more stable training + - `lpl_norm_type shared` provides best quality/speed trade-off + +## Technical Details + +### Feature Extraction + +The LPL implementation extracts features from the VAE decoder in the following order: +1. Middle block output +2. Up block outputs (configurable number of blocks) + +Each feature map is processed with: +1. Optional outlier removal (2% quantiles, morphological operations) +2. Feature normalization (shared statistics recommended) +3. Loss calculation (MSE or L1) +4. Optional power law weighting (2^(-i) for layer i) + +### Loss Calculation + +For each feature map: +1. Features are normalized according to the chosen strategy +2. Loss is calculated between normalized features +3. Outliers are masked out (if enabled) +4. Loss is weighted by layer depth (if power law enabled) +5. Final loss is averaged across all layers + +### Memory Considerations + +- Gradient checkpointing is used by default +- Features are extracted only for timesteps below the threshold +- Outlier removal is done in-place to save memory +- Feature normalization is done efficiently using vectorized operations +- Memory usage scales linearly with number of blocks used + +## Results + +Based on the paper's findings, LPL provides: +- 6-20% improvement in FID scores +- Better preservation of fine details +- More realistic textures and structures +- Improved consistency across different resolutions +- Better performance on both small and large datasets + +## Citation + +If you use this implementation in your research, please cite: + +```bibtex +@inproceedings{berrada2025boosting, + title={Boosting Latent Diffusion with Perceptual Objectives}, + author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek}, + booktitle={The Thirteenth International Conference on Learning Representations}, + year={2025}, + url={https://openreview.net/forum?id=y4DtzADzd1} +} +``` diff --git a/examples/research_projects/lpl/lpl_loss.py b/examples/research_projects/lpl/lpl_loss.py new file mode 100644 index 000000000000..de14a4d8d5aa --- /dev/null +++ b/examples/research_projects/lpl/lpl_loss.py @@ -0,0 +1,215 @@ +# Copyright 2025 Berrada et al. 
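Before reading `lpl_loss.py` below, it may help to see how the loss it defines is typically wired into a training step. The following is a minimal, illustrative sketch only: the flag names (`use_lpl`, `lpl_weight`, `lpl_t_threshold`, `lpl_start`) follow the README above, but the loop variables (`lpl_fn`, `pred_latents`, `target_latents`) are assumptions for illustration, not the training script's exact code.

```python
import torch
import torch.nn.functional as F

# Assumed to be built once from the frozen, fp32 VAE, e.g.:
# lpl_fn = LatentPerceptualLoss(vae, loss_type="mse", norm_type="shared", pow_law=True, num_mid_blocks=4)

def training_loss(model_pred, target, pred_latents, target_latents, timesteps, args, lpl_fn, global_step):
    # Standard diffusion objective on the model prediction.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Add the latent perceptual term only for low-noise (high-SNR) timesteps,
    # and only after the optional warm-up step.
    # pred_latents / target_latents are the predicted and reference clean (x0)
    # latents in VAE latent space.
    if args.use_lpl and global_step >= args.lpl_start:
        lpl_mask = timesteps < args.lpl_t_threshold
        if lpl_mask.any():
            lpl_value = lpl_fn.get_loss(pred_latents[lpl_mask], target_latents[lpl_mask]).mean()
            loss = loss + args.lpl_weight * lpl_value
    return loss
```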
+ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +def cross_normalize(input, target, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True)) + return input / (norm_factor + eps), target / (norm_factor + eps) + + +def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02): + opening = int(np.ceil(opening / down_f)) + closing = int(np.ceil(closing / down_f)) + if opening == 2: + opening = 3 + if closing == 2: + closing = 1 + + # replace quantile with kth value here. + feat_flat = feat.flatten(-2, -1) + k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant)) + q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None] + q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None] + + m = 2 * feat_flat.std(-1)[..., None, None].detach() + mask = (q1 - m < feat) * (feat < q2 + m) + + # dilate the mask. + mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing + mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening + feat = feat * mask + return mask, feat + + +class LatentPerceptualLoss(nn.Module): + def __init__( + self, + vae, + loss_type="mse", + grad_ckpt=True, + pow_law=False, + norm_type="default", + num_mid_blocks=4, + feature_type="feature", + remove_outliers=True, + ): + super().__init__() + self.vae = vae + self.decoder = self.vae.decoder + # Store scaling factors as tensors on the correct device + device = next(self.vae.parameters()).device + + # Get scaling factors with proper defaults and handle None values + scale_factor = getattr(self.vae.config, "scaling_factor", None) + shift_factor = getattr(self.vae.config, "shift_factor", None) + + # Convert to tensors with proper defaults + self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device) + self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device) + + self.gradient_checkpointing = grad_ckpt + self.pow_law = pow_law + self.norm_type = norm_type.lower() + self.outlier_mask = remove_outliers + self.last_feature_stats = [] # Store feature statistics for logging + + assert feature_type in ["feature", "image"] + self.feature_type = feature_type + + assert self.norm_type in ["default", "shared", "batch"] + assert num_mid_blocks >= 0 and num_mid_blocks <= 4 + self.n_blocks = num_mid_blocks + + assert loss_type in ["mse", "l1"] + if loss_type == "mse": + self.loss_fn = nn.MSELoss(reduction="none") + elif loss_type == "l1": + self.loss_fn = nn.L1Loss(reduction="none") + + def get_features(self, z, latent_embeds=None, disable_grads=False): + with torch.set_grad_enabled(not disable_grads): + if self.gradient_checkpointing and not disable_grads: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + features = [] + upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype + sample = z + sample = self.decoder.conv_in(sample) + + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.decoder.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + features.append(sample) + + # up + for up_block in self.decoder.up_blocks[: self.n_blocks]: + sample = 
torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + features.append(sample) + return features + else: + features = [] + upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype + sample = z + sample = self.decoder.conv_in(sample) + + # middle + sample = self.decoder.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + features.append(sample) + + # up + for up_block in self.decoder.up_blocks[: self.n_blocks]: + sample = up_block(sample, latent_embeds) + features.append(sample) + return features + + def get_loss(self, input, target, get_hist=False): + if self.feature_type == "feature": + inp_f = self.get_features(self.shift + input / self.scale) + tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True) + losses = [] + self.last_feature_stats = [] # Reset feature stats + + for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)): + my = torch.ones_like(y).bool() + outlier_ratio = 0.0 + + if self.outlier_mask: + with torch.no_grad(): + if i == 2: + my, y = remove_outliers(y, down_f=2) + outlier_ratio = 1.0 - my.float().mean().item() + elif i in [3, 4, 5]: + my, y = remove_outliers(y, down_f=1) + outlier_ratio = 1.0 - my.float().mean().item() + + # Store feature statistics before normalization + with torch.no_grad(): + stats = { + "mean": y.mean().item(), + "std": y.std().item(), + "outlier_ratio": outlier_ratio, + } + self.last_feature_stats.append(stats) + + # normalize feature tensors + if self.norm_type == "default": + x = normalize_tensor(x) + y = normalize_tensor(y) + elif self.norm_type == "shared": + x, y = cross_normalize(x, y, eps=1e-6) + + term_loss = self.loss_fn(x, y) * my + # reduce loss term + loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0 + term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3)) + losses.append(term_loss.mean((1,))) + + if get_hist: + return losses + else: + loss = sum(losses) + return loss / len(inp_f) + elif self.feature_type == "image": + inp_f = self.vae.decode(input / self.scale).sample + tar_f = self.vae.decode(target / self.scale).sample + return F.mse_loss(inp_f, tar_f) + + def get_first_conv(self, z): + sample = self.decoder.conv_in(z) + return sample + + def get_first_block(self, z): + sample = self.decoder.conv_in(z) + sample = self.decoder.mid_block(sample) + for resnet in self.decoder.up_blocks[0].resnets: + sample = resnet(sample, None) + return sample + + def get_first_layer(self, input, target, target_layer="conv"): + if target_layer == "conv": + feat_in = self.get_first_conv(input) + with torch.no_grad(): + feat_tar = self.get_first_conv(target) + else: + feat_in = self.get_first_block(input) + with torch.no_grad(): + feat_tar = self.get_first_block(target) + + feat_in, feat_tar = cross_normalize(feat_in, feat_tar) + + return F.mse_loss(feat_in, feat_tar, reduction="mean") diff --git a/examples/research_projects/lpl/train_sdxl_lpl.py b/examples/research_projects/lpl/train_sdxl_lpl.py new file mode 100644 index 000000000000..4c472c8871c0 --- /dev/null +++ b/examples/research_projects/lpl/train_sdxl_lpl.py @@ -0,0 +1,1622 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LPL training script for Stable Diffusion XL for text2image.""" + +import argparse +import functools +import gc +import logging +import math +import os +import random +import re +import shutil +from contextlib import nullcontext +from pathlib import Path +from typing import Dict, List, Tuple + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedType, ProjectConfiguration, set_seed +from datasets import concatenate_datasets, load_dataset +from huggingface_hub import create_repo, upload_folder +from lpl_loss import LatentPerceptualLoss +from packaging import version +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.33.0.dev0") + +logger = get_logger(__name__) +if is_torch_npu_available(): + import torch_npu + + torch.npu.config.allow_internal_format = False + +DATASET_NAME_MAPPING = { + "lambdalabs/naruto-blip-captions": ("image", "text"), +} + +# Global dictionary to store intermediate features from hooks +hook_features: Dict[str, torch.Tensor] = {} + + +def get_intermediate_features_hook(name: str): + """Creates a hook function that saves the output of a layer.""" + + def hook(model, input, output): + # Some layers might return tuples (e.g., attention blocks) + # We are usually interested in the first element (hidden states) + if isinstance(output, tuple): + hook_features[name] = output[0] + else: + hook_features[name] = output + + return hook + + +def clear_hook_features(): + """Clears the global feature dictionary.""" + global hook_features + hook_features = {} + + +def normalize_features( + feat1: torch.Tensor, feat2: torch.Tensor, eps: float = 1e-6 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Normalizes feat1 and feat2 using the statistics of feat2 (predicted features). + Normalization is done per-channel. 
+ """ + # Calculate stats over spatial dimensions (H, W) + dims = tuple(range(2, feat2.ndim)) # Dims to reduce over (usually 2, 3 for H, W) + mean = torch.mean(feat2, dim=dims, keepdim=True) + std = torch.std(feat2, dim=dims, keepdim=True) + eps + + feat1_norm = (feat1 - mean) / std + feat2_norm = (feat2 - mean) / std + return feat1_norm, feat2_norm + + +def get_decoder_layer_names(decoder: nn.Module) -> List[str]: + """Helper to get potential layer names for hooks in the VAE decoder.""" + layer_names = [] + for name, module in decoder.named_modules(): + # Example: Target ResnetBlocks and potentially UpBlocks + if isinstance(module, (diffusers.models.resnet.ResnetBlock2D, diffusers.models.unet_2d_blocks.UpBlock2D)): + # Filter out redundant names if UpBlock contains ResnetBlocks already named + is_child = any( + name.startswith(parent + ".") + for parent in layer_names + if isinstance(decoder.get_submodule(parent), diffusers.models.unet_2d_blocks.UpBlock2D) + ) + if not is_child: + layer_names.append(name) + # A basic default selection if complex logic fails + if not layer_names: + layer_names = [ + name for name, module in decoder.named_modules() if re.match(r"up_blocks\.\d+\.resnets\.\d+$", name) + ] + return layer_names + + +def save_model_card( + repo_id: str, + images: list = None, + validation_prompt: str = None, + base_model: str = None, + dataset_name: str = None, + repo_folder: str = None, + vae_path: str = None, +): + img_str = "" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + model_description = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n +{img_str} + +Special VAE used for training: {vae_path}. 
+""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion-xl", + "stable-diffusion-xl-diffusers", + "text-to-image", + "diffusers-training", + "diffusers", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="LPL based training script of Stable Diffusion XL.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." 
+ ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=1, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--output_dir", + type=str, + default="sdxl-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention." + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + parser.add_argument( + "--use_lpl", + action="store_true", + help="Whether to use Latent Perceptual Loss (LPL). Increases memory usage.", + ) + parser.add_argument( + "--lpl_weight", + type=float, + default=1.0, + help="Weight for the Latent Perceptual Loss.", + ) + parser.add_argument( + "--lpl_t_threshold", + type=int, + default=200, + help="Apply LPL only for timesteps t < lpl_t_threshold. 
Corresponds to high SNR.", + ) + parser.add_argument( + "--lpl_loss_type", + type=str, + default="mse", + choices=["mse", "l1"], + help="Type of loss to use for LPL.", + ) + parser.add_argument( + "--lpl_norm_type", + type=str, + default="default", + choices=["default", "shared", "batch"], + help="Type of normalization to use for LPL features.", + ) + parser.add_argument( + "--lpl_pow_law", + action="store_true", + help="Whether to use power law weighting for LPL layers.", + ) + parser.add_argument( + "--lpl_num_blocks", + type=int, + default=4, + help="Number of up blocks to use for LPL feature extraction.", + ) + parser.add_argument( + "--lpl_remove_outliers", + action="store_true", + help="Whether to remove outliers in LPL feature maps.", + ) + parser.add_argument( + "--lpl_scale", + action="store_true", + help="Whether to scale LPL loss by noise level weights.", + ) + parser.add_argument( + "--lpl_start", + type=int, + default=0, + help="Step to start applying LPL loss.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True): + prompt_embeds_list = [] + prompt_batch = batch[caption_column] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + return_dict=False, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-1][-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} + + +def compute_vae_encodings(batch, vae): + images = batch.pop("pixel_values") + pixel_values = torch.stack(list(images)) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + pixel_values = pixel_values.to(vae.device, dtype=vae.dtype) + + with torch.no_grad(): + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + + # There 
might have slightly performance improvement + # by changing model_input.cpu() to accelerator.gather(model_input) + return {"model_input": model_input.cpu()} + + +def generate_timestep_weights(args, num_timesteps): + weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == "later": + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == "earlier": + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == "range": + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." + ) + if range_end > num_timesteps: + raise ValueError( + "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." + ) + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." + " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." + " A timestep bias multiplier less than or equal to 0 is not allowed." + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + # Check for terminal SNR in combination with SNR Gamma + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # Freeze vae and text encoders. + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + # Set unet as trainable. + unet.train() + + # For mixed precision training we cast all non-trainable weights to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Create EMA for the unet. 
+ if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + unet.enable_npu_flash_attention() + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.") + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for _ in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = unet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. 
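Editor's note: when `--dataset_name` is not set, the `imagefolder` branch above loads images straight from `--train_data_dir`, and captions come from a `metadata.jsonl` placed next to the images (the standard 🤗 Datasets ImageFolder convention). A hedged sketch of the expected layout; file names and captions are placeholders.

```python
# train_data_dir/
# ├── metadata.jsonl                 # one JSON object per line, keyed by "file_name"
# ├── img_0001.png
# └── img_0002.png
#
# metadata.jsonl:
#   {"file_name": "img_0001.png", "text": "a photo of a corgi wearing sunglasses"}
#   {"file_name": "img_0002.png", "text": "an oil painting of a lighthouse at dusk"}

from datasets import load_dataset

dataset = load_dataset("imagefolder", data_dir="train_data_dir")
print(dataset["train"].column_names)  # e.g. ['image', 'text'] -> the image and caption columns resolved above
```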
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + # image aug + original_sizes = [] + all_images = [] + crop_top_lefts = [] + for image in images: + original_sizes.append((image.height, image.width)) + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(image) + all_images.append(image) + + examples["original_sizes"] = original_sizes + examples["crop_top_lefts"] = crop_top_lefts + examples["pixel_values"] = all_images + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. We will pre-compute the VAE encodings too. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + compute_embeddings_fn = functools.partial( + encode_prompt, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + caption_column=args.caption_column, + ) + compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + new_fingerprint_for_vae = Hasher.hash((vae_path, args)) + train_dataset_with_embeddings = train_dataset.map( + compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint + ) + train_dataset_with_vae = train_dataset.map( + compute_vae_encodings_fn, + batched=True, + batch_size=args.train_batch_size, + new_fingerprint=new_fingerprint_for_vae, + ) + precomputed_dataset = concatenate_datasets( + [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1 + ) + precomputed_dataset = precomputed_dataset.with_transform(preprocess_train) + + del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two + del text_encoders, tokenizers + if not args.use_lpl: + del vae + gc.collect() + + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + elif torch.cuda.is_available(): + torch.cuda.empty_cache() + + def collate_fn(examples): + model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) + original_sizes = [example["original_sizes"] for example in examples] + crop_top_lefts = 
[example["crop_top_lefts"] for example in examples] + prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) + + return { + "model_input": model_input, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + precomputed_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) + + if args.use_lpl: + lpl_fn = LatentPerceptualLoss( + vae=vae, + loss_type=args.lpl_loss_type, + grad_ckpt=args.gradient_checkpointing, + pow_law=args.lpl_pow_law, + norm_type=args.lpl_norm_type, + num_mid_blocks=args.lpl_num_blocks, + feature_type="feature", + remove_outliers=args.lpl_remove_outliers, + ) + lpl_fn.to(accelerator.device) + else: + lpl_fn = None + + # Function for unwrapping if torch.compile() was used in accelerate. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path: + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(precomputed_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # Get scheduler alphas and sigmas for LPL z0_hat calculation + alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # Sample noise that we'll add to the latents + model_input = batch["model_input"].to(accelerator.device) + noise = torch.randn_like(model_input) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device + ) + + bsz = model_input.shape[0] + if args.timestep_bias_strategy == "none": + # Sample a random timestep for each image without bias. + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + else: + # Sample a random timestep for each image, potentially biased by the timestep weights. + # Biasing the timestep weights allows us to spend less time training irrelevant timesteps. 
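Editor's note: the biased-sampling comment above is implemented just below via `generate_timestep_weights`, which is defined earlier in the script and not shown in this hunk. As a hedged illustration of the idea only (the weighting function below is hypothetical, not the script's helper), biasing simply means drawing timesteps from a non-uniform distribution with `torch.multinomial`.

```python
import torch

# Hypothetical weighting: up-weight the later (noisier) half of the schedule.
def example_timestep_weights(num_train_timesteps: int, later_bias: float = 2.0) -> torch.Tensor:
    weights = torch.ones(num_train_timesteps)
    weights[num_train_timesteps // 2 :] *= later_bias
    return weights / weights.sum()  # normalized for clarity; multinomial also accepts unnormalized weights

weights = example_timestep_weights(1000)
timesteps = torch.multinomial(weights, num_samples=4, replacement=True).long()
print(timesteps)  # later timesteps are drawn roughly twice as often as earlier ones
```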
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to( + model_input.device + ) + timesteps = torch.multinomial(weights, bsz, replacement=True).long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype) + + # time ids + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype) + return add_time_ids + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ) + + # Predict the noise residual + unet_added_conditions = {"time_ids": add_time_ids} + prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + elif noise_scheduler.config.prediction_type == "sample": + # We set the target to latents here, but the model_pred will return the noise sample prediction. + target = model_input + # We will have to subtract the noise residual from the prediction to get the target sample. + model_pred = model_pred - noise + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. 
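Editor's note: for reference, the Min-SNR-gamma weights computed in the next few lines can be written out directly from the scheduler's `alphas_cumprod`. A self-contained hedged sketch with a toy schedule standing in for the real one; the script itself goes through the `compute_snr` helper.

```python
import torch

def min_snr_weights(alphas_cumprod, timesteps, snr_gamma, prediction_type):
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)                      # SNR(t) = abar_t / (1 - abar_t)
    weights = torch.minimum(snr, torch.full_like(snr, snr_gamma))
    if prediction_type == "epsilon":
        weights = weights / snr                              # min(SNR, gamma) / SNR
    elif prediction_type == "v_prediction":
        weights = weights / (snr + 1.0)                      # min(SNR, gamma) / (SNR + 1)
    return weights

alphas_cumprod = torch.linspace(0.9999, 0.0005, 1000)        # toy schedule, for illustration only
print(min_snr_weights(alphas_cumprod, torch.tensor([10, 500, 990]), 5.0, "epsilon"))
```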
+ snr = compute_snr(noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + if noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = mse_loss_weights / snr + elif noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights / (snr + 1) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + lpl_loss_value = torch.tensor(0.0, device=accelerator.device) + if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start: + # Apply LPL only below the timestep threshold + lpl_mask = timesteps < args.lpl_t_threshold + if lpl_mask.any(): + # Select samples that meet the threshold + masked_indices = torch.where(lpl_mask)[0] + z0_masked = model_input[masked_indices] + zt_masked = noisy_model_input[masked_indices] + t_masked = timesteps[masked_indices] + model_pred_masked = model_pred[masked_indices] + + # Calculate z0_hat for the masked samples + alpha_t = alphas_cumprod[t_masked].sqrt().to(torch.float32) + sigma_t = (1 - alphas_cumprod[t_masked]).sqrt().to(torch.float32) + alpha_t = alpha_t.view(-1, 1, 1, 1) + sigma_t = sigma_t.view(-1, 1, 1, 1) + + if noise_scheduler.config.prediction_type == "epsilon": + z0_hat_masked = (zt_masked.float() - sigma_t * model_pred_masked.float()) / alpha_t + elif noise_scheduler.config.prediction_type == "v_prediction": + z0_hat_masked = alpha_t * zt_masked.float() - sigma_t * model_pred_masked.float() + else: # sample prediction + z0_hat_masked = model_pred_masked.float() + + with accelerator.autocast(): + lpl_loss_value = lpl_fn.get_loss(z0_hat_masked, z0_masked) + + if args.lpl_scale: + if args.snr_gamma is not None: + # Use SNR-based weights if available + snr = compute_snr(noise_scheduler, t_masked) + snr_weights = torch.stack( + [snr, args.snr_gamma * torch.ones_like(t_masked)], dim=1 + ).min(dim=1)[0] + if noise_scheduler.config.prediction_type == "epsilon": + snr_weights = snr_weights / snr + elif noise_scheduler.config.prediction_type == "v_prediction": + snr_weights = snr_weights / (snr + 1) + lpl_loss_value = (lpl_loss_value * snr_weights).mean() + else: + # If no SNR weighting, just use mean + lpl_loss_value = lpl_loss_value.mean() + else: + lpl_loss_value = lpl_loss_value.mean() + + # Combine losses + total_loss = loss + args.lpl_weight * lpl_loss_value + + # Gather the losses across all processes for logging + avg_loss = accelerator.gather(total_loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(total_loss) + if accelerator.sync_gradients: + params_to_clip = unet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + + # Enhanced logging for LPL metrics + log_data = { + "train_loss": train_loss, + "diffusion_loss": loss.item(), + "learning_rate": lr_scheduler.get_last_lr()[0], + } + + if args.use_lpl and lpl_fn is not None and global_step >= args.lpl_start: + if lpl_mask.any(): + # LPL application statistics + log_data.update( + { + "lpl/loss": lpl_loss_value.item(), + 
"lpl/num_samples": lpl_mask.sum().item(), + "lpl/application_ratio": lpl_mask.float().mean().item(), + "lpl/weight": args.lpl_weight, + "lpl/weighted_loss": (args.lpl_weight * lpl_loss_value).item(), + } + ) + + # SNR statistics for LPL-applied samples + if args.snr_gamma is not None: + snr_values = snr[masked_indices] + log_data.update( + { + "lpl/snr_mean": snr_values.mean().item(), + "lpl/snr_std": snr_values.std().item(), + "lpl/snr_min": snr_values.min().item(), + "lpl/snr_max": snr_values.max().item(), + } + ) + + # Feature statistics if available + if hasattr(lpl_fn, "last_feature_stats"): + for layer_idx, stats in enumerate(lpl_fn.last_feature_stats): + log_data.update( + { + f"lpl/features/layer_{layer_idx}/mean": stats["mean"], + f"lpl/features/layer_{layer_idx}/std": stats["std"], + f"lpl/features/layer_{layer_idx}/outlier_ratio": stats.get( + "outlier_ratio", 0.0 + ), + } + ) + + # Memory usage if available + if torch.cuda.is_available(): + log_data.update( + { + "lpl/memory/allocated": torch.cuda.memory_allocated() / 1024**2, # MB + "lpl/memory/reserved": torch.cuda.memory_reserved() / 1024**2, # MB + } + ) + + # Log to accelerator + accelerator.log(log_data, step=global_step) + + # Update progress bar with more metrics + progress_bar_logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + if args.use_lpl and lpl_loss_value.item() > 0: + progress_bar_logs.update( + { + "lpl": lpl_loss_value.item(), + "lpl_ratio": lpl_mask.float().mean().item() if lpl_mask.any() else 0.0, + } + ) + progress_bar.set_postfix(**progress_bar_logs) + + # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues. + if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
+ ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + # create pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if args.prediction_type is not None: + scheduler_args = {"prediction_type": args.prediction_type} + pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) + if args.seed is not None + else None + ) + pipeline_args = {"prompt": args.validation_prompt} + + with autocast_ctx: + images = [ + pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + elif torch.cuda.is_available(): + torch.cuda.empty_cache() + + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + # Serialize pipeline. 
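Editor's note: once the pipeline serialized below is written to `--output_dir`, it reloads like any other SDXL checkpoint. A hedged usage sketch; the path and prompt are placeholders.

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Reload the fine-tuned pipeline saved by this script and generate one image.
pipe = StableDiffusionXLPipeline.from_pretrained("path/to/output_dir", torch_dtype=torch.float16)
pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("sample.png")
```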
+ vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unet, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if args.prediction_type is not None: + scheduler_args = {"prediction_type": args.prediction_type} + pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.save_pretrained(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline = pipeline.to(accelerator.device) + generator = ( + torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + ) + + with autocast_ctx: + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id=repo_id, + images=images, + validation_prompt=args.validation_prompt, + base_model=args.pretrained_model_name_or_path, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/research_projects/onnxruntime/text_to_image/README.md b/examples/research_projects/onnxruntime/text_to_image/README.md index f398f081663a..1d688471ba74 100644 --- a/examples/research_projects/onnxruntime/text_to_image/README.md +++ b/examples/research_projects/onnxruntime/text_to_image/README.md @@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode ___Note___: -___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___ +___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___ ## Running locally with PyTorch diff --git a/examples/research_projects/sdxl_flax/sdxl_single.py b/examples/research_projects/sdxl_flax/sdxl_single.py index 5b9b862d99b5..c3cbf6ca24f0 100644 --- a/examples/research_projects/sdxl_flax/sdxl_single.py +++ b/examples/research_projects/sdxl_flax/sdxl_single.py @@ -18,7 +18,7 @@ NUM_DEVICES = jax.device_count() # 1. 
Let's start by downloading the model and loading it into our pipeline class -# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and +# Adhering to JAX's functional approach, the model's parameters are returned separately and # will have to be passed to the pipeline during inference pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py index 57d1e2567169..9c3276c31c69 100644 --- a/examples/server-async/utils/requestscopedpipeline.py +++ b/examples/server-async/utils/requestscopedpipeline.py @@ -7,16 +7,12 @@ from diffusers.utils import logging from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps +from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper logger = logging.get_logger(__name__) -def safe_tokenize(tokenizer, *args, lock, **kwargs): - with lock: - return tokenizer(*args, **kwargs) - - class RequestScopedPipeline: DEFAULT_MUTABLE_ATTRS = [ "_all_hooks", @@ -38,23 +34,40 @@ def __init__( wrap_scheduler: bool = True, ): self._base = pipeline + self.unet = getattr(pipeline, "unet", None) self.vae = getattr(pipeline, "vae", None) self.text_encoder = getattr(pipeline, "text_encoder", None) self.components = getattr(pipeline, "components", None) + self.transformer = getattr(pipeline, "transformer", None) + if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None: if not isinstance(pipeline.scheduler, BaseAsyncScheduler): pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler) self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS) + self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock() + self._vae_lock = threading.Lock() + self._image_lock = threading.Lock() + self._auto_detect_mutables = bool(auto_detect_mutables) self._tensor_numel_threshold = int(tensor_numel_threshold) - self._auto_detected_attrs: List[str] = [] + def _detect_kernel_pipeline(self, pipeline) -> bool: + kernel_indicators = [ + "text_encoding_cache", + "memory_manager", + "enable_optimizations", + "_create_request_context", + "get_optimization_stats", + ] + + return any(hasattr(pipeline, attr) for attr in kernel_indicators) + def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs): base_sched = getattr(self._base, "scheduler", None) if base_sched is None: @@ -70,11 +83,21 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] num_inference_steps=num_inference_steps, device=device, **clone_kwargs ) except Exception as e: - logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()") + logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback") try: - return copy.deepcopy(wrapped_scheduler) - except Exception as e: - logger.warning(f"Deepcopy of scheduler failed: {e}. 
Returning original scheduler (*risky*).") + if hasattr(wrapped_scheduler, "scheduler"): + try: + copied_scheduler = copy.copy(wrapped_scheduler.scheduler) + return BaseAsyncScheduler(copied_scheduler) + except Exception: + return wrapped_scheduler + else: + copied_scheduler = copy.copy(wrapped_scheduler) + return BaseAsyncScheduler(copied_scheduler) + except Exception as e2: + logger.warning( + f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)." + ) return wrapped_scheduler def _autodetect_mutables(self, max_attrs: int = 40): @@ -86,6 +109,7 @@ def _autodetect_mutables(self, max_attrs: int = 40): candidates: List[str] = [] seen = set() + for name in dir(self._base): if name.startswith("__"): continue @@ -93,6 +117,7 @@ def _autodetect_mutables(self, max_attrs: int = 40): continue if name in ("to", "save_pretrained", "from_pretrained"): continue + try: val = getattr(self._base, name) except Exception: @@ -100,11 +125,9 @@ def _autodetect_mutables(self, max_attrs: int = 40): import types - # skip callables and modules if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)): continue - # containers -> candidate if isinstance(val, (dict, list, set, tuple, bytearray)): candidates.append(name) seen.add(name) @@ -205,6 +228,9 @@ def _is_tokenizer_component(self, component) -> bool: return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs) + def _should_wrap_tokenizers(self) -> bool: + return True + def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs): local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device) @@ -214,6 +240,25 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).") local_pipe = copy.deepcopy(self._base) + try: + if ( + hasattr(local_pipe, "vae") + and local_pipe.vae is not None + and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper) + ): + local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock) + + if ( + hasattr(local_pipe, "image_processor") + and local_pipe.image_processor is not None + and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper) + ): + local_pipe.image_processor = ThreadSafeImageProcessorWrapper( + local_pipe.image_processor, self._image_lock + ) + except Exception as e: + logger.debug(f"Could not wrap vae/image_processor: {e}") + if local_scheduler is not None: try: timesteps, num_steps, configured_scheduler = async_retrieve_timesteps( @@ -231,47 +276,42 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = self._clone_mutable_attrs(self._base, local_pipe) - # 4) wrap tokenizers on the local pipe with the lock wrapper - tokenizer_wrappers = {} # name -> original_tokenizer - try: - # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...) 
- for name in dir(local_pipe): - if "tokenizer" in name and not name.startswith("_"): - tok = getattr(local_pipe, name, None) - if tok is not None and self._is_tokenizer_component(tok): - tokenizer_wrappers[name] = tok - setattr( - local_pipe, - name, - lambda *args, tok=tok, **kwargs: safe_tokenize( - tok, *args, lock=self._tokenizer_lock, **kwargs - ), - ) - - # b) wrap tokenizers in components dict - if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict): - for key, val in local_pipe.components.items(): - if val is None: - continue - - if self._is_tokenizer_component(val): - tokenizer_wrappers[f"components[{key}]"] = val - local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize( - tokenizer, *args, lock=self._tokenizer_lock, **kwargs - ) + original_tokenizers = {} - except Exception as e: - logger.debug(f"Tokenizer wrapping step encountered an error: {e}") + if self._should_wrap_tokenizers(): + try: + for name in dir(local_pipe): + if "tokenizer" in name and not name.startswith("_"): + tok = getattr(local_pipe, name, None) + if tok is not None and self._is_tokenizer_component(tok): + if not isinstance(tok, ThreadSafeTokenizerWrapper): + original_tokenizers[name] = tok + wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock) + setattr(local_pipe, name, wrapped_tokenizer) + + if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict): + for key, val in local_pipe.components.items(): + if val is None: + continue + + if self._is_tokenizer_component(val): + if not isinstance(val, ThreadSafeTokenizerWrapper): + original_tokenizers[f"components[{key}]"] = val + wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock) + local_pipe.components[key] = wrapped_tokenizer + + except Exception as e: + logger.debug(f"Tokenizer wrapping step encountered an error: {e}") result = None cm = getattr(local_pipe, "model_cpu_offload_context", None) + try: if callable(cm): try: with cm(): result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) except TypeError: - # cm might be a context manager instance rather than callable try: with cm: result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) @@ -279,18 +319,18 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = logger.debug(f"model_cpu_offload_context usage failed: {e}. 
Proceeding without it.") result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) else: - # no offload context available — call directly result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs) return result finally: try: - for name, tok in tokenizer_wrappers.items(): + for name, tok in original_tokenizers.items(): if name.startswith("components["): key = name[len("components[") : -1] - local_pipe.components[key] = tok + if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict): + local_pipe.components[key] = tok else: setattr(local_pipe, name, tok) except Exception as e: - logger.debug(f"Error restoring wrapped tokenizers: {e}") + logger.debug(f"Error restoring original tokenizers: {e}") diff --git a/examples/server-async/utils/wrappers.py b/examples/server-async/utils/wrappers.py new file mode 100644 index 000000000000..1e8474eabf3f --- /dev/null +++ b/examples/server-async/utils/wrappers.py @@ -0,0 +1,86 @@ +class ThreadSafeTokenizerWrapper: + def __init__(self, tokenizer, lock): + self._tokenizer = tokenizer + self._lock = lock + + self._thread_safe_methods = { + "__call__", + "encode", + "decode", + "tokenize", + "encode_plus", + "batch_encode_plus", + "batch_decode", + } + + def __getattr__(self, name): + attr = getattr(self._tokenizer, name) + + if name in self._thread_safe_methods and callable(attr): + + def wrapped_method(*args, **kwargs): + with self._lock: + return attr(*args, **kwargs) + + return wrapped_method + + return attr + + def __call__(self, *args, **kwargs): + with self._lock: + return self._tokenizer(*args, **kwargs) + + def __setattr__(self, name, value): + if name.startswith("_"): + super().__setattr__(name, value) + else: + setattr(self._tokenizer, name, value) + + def __dir__(self): + return dir(self._tokenizer) + + +class ThreadSafeVAEWrapper: + def __init__(self, vae, lock): + self._vae = vae + self._lock = lock + + def __getattr__(self, name): + attr = getattr(self._vae, name) + if name in {"decode", "encode", "forward"} and callable(attr): + + def wrapped(*args, **kwargs): + with self._lock: + return attr(*args, **kwargs) + + return wrapped + return attr + + def __setattr__(self, name, value): + if name.startswith("_"): + super().__setattr__(name, value) + else: + setattr(self._vae, name, value) + + +class ThreadSafeImageProcessorWrapper: + def __init__(self, proc, lock): + self._proc = proc + self._lock = lock + + def __getattr__(self, name): + attr = getattr(self._proc, name) + if name in {"postprocess", "preprocess"} and callable(attr): + + def wrapped(*args, **kwargs): + with self._lock: + return attr(*args, **kwargs) + + return wrapped + return attr + + def __setattr__(self, name, value): + if name.startswith("_"): + super().__setattr__(name, value) + else: + setattr(self._proc, name, value) diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index 989ac6e0c45e..32128ebbd4df 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 7ebf7b5465a5..90dd06d33c5e 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -57,7 +57,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index c4f36879f328..e474445d9afe 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -49,7 +49,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 1fd48dcd159d..310a50ac4e9a 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 5fb1825f37d3..88f5c3cede6e 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -68,7 +68,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index c26cb4484125..4eafa8f28a19 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -55,7 +55,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index caa77e4bbaf5..0d8c25349fca 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -82,7 +82,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 4a03d9bf6ba9..7fb394a1bd15 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -56,7 +56,7 @@ # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. 
Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index 51de29a71a47..3f482341ca4a 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -77,7 +77,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 0cc96220b932..ed7d2db43700 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -29,7 +29,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index eeb592a3f7d9..d9ad2774e897 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -50,7 +50,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.36.0.dev0") +check_min_version("0.37.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index 6f6563ad641b..bc6014068e87 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -1,11 +1,94 @@ +""" +# Cosmos 2 Predict + +Download checkpoint +```bash +hf download nvidia/Cosmos-Predict2-2B-Text2Image +``` + +convert checkpoint +```bash +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_ckpt_path $transformer_ckpt_path \ + --transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \ + --text_encoder_path google-t5/t5-11b \ + --tokenizer_path google-t5/t5-11b \ + --vae_type wan2.1 \ + --output_path converted/cosmos-p2-t2i-2b \ + --save_pipeline +``` + +# Cosmos 2.5 Predict + +Download checkpoint +```bash +hf download nvidia/Cosmos-Predict2.5-2B +``` + +Convert checkpoint +```bash +# pre-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \ + --save_pipeline + +# post-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \ + --save_pipeline +``` 
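Editor's note: the commands above hard-code a snapshot hash under `~/.cache`; if that path drifts, the same file can be resolved programmatically. A hedged alternative using the repo and file name from the pre-trained 2B example above.

```python
from huggingface_hub import hf_hub_download

# Resolves (downloading if necessary) the cached path of the checkpoint referenced above,
# without hard-coding the snapshot hash.
ckpt_path = hf_hub_download(
    repo_id="nvidia/Cosmos-Predict2.5-2B",
    filename="base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt",
)
print(ckpt_path)  # pass this as --transformer_ckpt_path
```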
+ +## 14B + +```bash +hf download nvidia/Cosmos-Predict2.5-14B +``` + +```bash +# pre-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-14B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \ + --save_pipeline + +# post-trained +transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Predict-Base-14B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \ + --save_pipeline +``` + +""" + import argparse import pathlib +import sys from typing import Any, Dict import torch from accelerate import init_empty_weights from huggingface_hub import snapshot_download -from transformers import T5EncoderModel, T5TokenizerFast +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast from diffusers import ( AutoencoderKLCosmos, @@ -17,7 +100,9 @@ CosmosVideoToWorldPipeline, EDMEulerScheduler, FlowMatchEulerDiscreteScheduler, + UniPCMultistepScheduler, ) +from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline def remove_keys_(key: str, state_dict: Dict[str, Any]): @@ -233,6 +318,44 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]): "concat_padding_mask": True, "extra_pos_embed_type": None, }, + "Cosmos-2.5-Predict-Base-2B": { + "in_channels": 16 + 1, + "out_channels": 16, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 28, + "mlp_ratio": 4.0, + "text_embed_dim": 1024, + "adaln_lora_dim": 256, + "max_size": (128, 240, 240), + "patch_size": (1, 2, 2), + "rope_scale": (1.0, 3.0, 3.0), + "concat_padding_mask": True, + # NOTE: source config has pos_emb_learnable: 'True' - but params are missing + "extra_pos_embed_type": None, + "use_crossattn_projection": True, + "crossattn_proj_in_channels": 100352, + "encoder_hidden_states_channels": 1024, + }, + "Cosmos-2.5-Predict-Base-14B": { + "in_channels": 16 + 1, + "out_channels": 16, + "num_attention_heads": 40, + "attention_head_dim": 128, + "num_layers": 36, + "mlp_ratio": 4.0, + "text_embed_dim": 1024, + "adaln_lora_dim": 256, + "max_size": (128, 240, 240), + "patch_size": (1, 2, 2), + "rope_scale": (1.0, 3.0, 3.0), + "concat_padding_mask": True, + # NOTE: source config has pos_emb_learnable: 'True' - but params are missing + "extra_pos_embed_type": None, + "use_crossattn_projection": True, + "crossattn_proj_in_channels": 100352, + "encoder_hidden_states_channels": 1024, + }, } VAE_KEYS_RENAME_DICT = { @@ -334,6 +457,9 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo elif "Cosmos-2.0" in transformer_type: TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 + elif "Cosmos-2.5" in transformer_type: + TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 + TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 
else: assert False @@ -347,6 +473,7 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo new_key = new_key.removeprefix(PREFIX_KEY) for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) + print(key, "->", new_key, flush=True) update_state_dict_(original_state_dict, key, new_key) for key in list(original_state_dict.keys()): @@ -355,6 +482,21 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo continue handler_fn_inplace(key, original_state_dict) + expected_keys = set(transformer.state_dict().keys()) + mapped_keys = set(original_state_dict.keys()) + missing_keys = expected_keys - mapped_keys + unexpected_keys = mapped_keys - expected_keys + if missing_keys: + print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr) + for k in missing_keys: + print(k) + sys.exit(1) + if unexpected_keys: + print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr) + for k in unexpected_keys: + print(k) + sys.exit(2) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer @@ -444,6 +586,34 @@ def save_pipeline_cosmos_2_0(args, transformer, vae): pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") +def save_pipeline_cosmos2_5(args, transformer, vae): + text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B" + tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct" + + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + text_encoder_path, torch_dtype="auto", device_map="cpu" + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + scheduler = UniPCMultistepScheduler( + use_karras_sigmas=True, + use_flow_sigmas=True, + prediction_type="flow_prediction", + sigma_max=200.0, + sigma_min=0.01, + ) + + pipe = Cosmos2_5_PredictBasePipeline( + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + vae=vae, + scheduler=scheduler, + safety_checker=lambda *args, **kwargs: None, + ) + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys())) @@ -451,10 +621,10 @@ def get_args(): "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" ) parser.add_argument( - "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE" + "--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE" ) - parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b") - parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b") + parser.add_argument("--text_encoder_path", type=str, default=None) + parser.add_argument("--tokenizer_path", type=str, default=None) parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") @@ -477,8 +647,6 @@ def get_args(): if args.save_pipeline: assert args.transformer_ckpt_path is not None assert args.vae_type is not None - assert args.text_encoder_path is not None - assert args.tokenizer_path is 
not None if args.transformer_ckpt_path is not None: weights_only = "Cosmos-1.0" in args.transformer_type @@ -490,17 +658,26 @@ def get_args(): if args.vae_type is not None: if "Cosmos-1.0" in args.transformer_type: vae = convert_vae(args.vae_type) - else: + elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type: vae = AutoencoderKLWan.from_pretrained( "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 ) + else: + raise AssertionError(f"{args.transformer_type} not supported") + if not args.save_pipeline: vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") if args.save_pipeline: if "Cosmos-1.0" in args.transformer_type: + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None save_pipeline_cosmos_1_0(args, transformer, vae) elif "Cosmos-2.0" in args.transformer_type: + assert args.text_encoder_path is not None + assert args.tokenizer_path is not None save_pipeline_cosmos_2_0(args, transformer, vae) + elif "Cosmos-2.5" in args.transformer_type: + save_pipeline_cosmos2_5(args, transformer, vae) else: - assert False + raise AssertionError(f"{args.transformer_type} not supported") diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py index 2973913fa215..a8fa6f87eee1 100644 --- a/scripts/convert_flux2_to_diffusers.py +++ b/scripts/convert_flux2_to_diffusers.py @@ -44,7 +44,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--original_state_dict_repo_id", default=None, type=str) parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str) -parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str) +parser.add_argument("--dit_filename", default="flux2-dev.safetensors", type=str) parser.add_argument("--vae", action="store_true") parser.add_argument("--dit", action="store_true") parser.add_argument("--vae_dtype", type=str, default="fp32") @@ -385,9 +385,9 @@ def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: - if model_type == "test" or model_type == "dummy-flux2": + if model_type == "flux2-dev": config = { - "model_id": "diffusers-internal-dev/dummy-flux2", + "model_id": "black-forest-labs/FLUX.2-dev", "diffusers_config": { "patch_size": 1, "in_channels": 128, @@ -405,6 +405,53 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "klein-4b": + config = { + "model_id": "diffusers-internal-dev/dummy0115", + "diffusers_config": { + "patch_size": 1, + "in_channels": 128, + "num_layers": 5, + "num_single_layers": 20, + "attention_head_dim": 128, + "num_attention_heads": 24, + "joint_attention_dim": 7680, + "timestep_guidance_channels": 256, + "mlp_ratio": 3.0, + "axes_dims_rope": (32, 32, 32, 32), + "rope_theta": 2000, + "eps": 1e-6, + "guidance_embeds": False, + }, + } + rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + + elif model_type == "klein-9b": + config = { + "model_id": "diffusers-internal-dev/dummy0115", + "diffusers_config": { + "patch_size": 1, + "in_channels": 128, + "num_layers": 8, + "num_single_layers": 24, + "attention_head_dim": 128, + "num_attention_heads": 32, + "joint_attention_dim": 12288, + "timestep_guidance_channels": 256, + "mlp_ratio": 3.0, + "axes_dims_rope": (32, 32, 
32, 32), + "rope_theta": 2000, + "eps": 1e-6, + "guidance_embeds": False, + }, + } + rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP + + else: + raise ValueError(f"Unknown model_type: {model_type}. Choose from: flux2-dev, klein-4b, klein-9b") + return config, rename_dict, special_keys_remap @@ -447,7 +494,14 @@ def main(args): if args.dit: original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename) - transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test") + + if "klein-4b" in args.dit_filename: + model_type = "klein-4b" + elif "klein-9b" in args.dit_filename: + model_type = "klein-9b" + else: + model_type = "flux2-dev" + transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, model_type) if not args.full_pipe: dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32 transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer") @@ -465,8 +519,15 @@ def main(args): "black-forest-labs/FLUX.1-dev", subfolder="scheduler" ) + if_distilled = "base" not in args.dit_filename + pipe = Flux2Pipeline( - vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + if_distilled=if_distilled, ) pipe.save_pretrained(args.output_path) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 38226f684a6d..89e5cdb16956 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -69,6 +69,11 @@ "target_size": 960, "task_type": "i2v", }, + "480p_i2v_step_distilled": { + "target_size": 640, + "task_type": "i2v", + "use_meanflow": True, + }, } SCHEDULER_CONFIGS = { @@ -93,6 +98,9 @@ "720p_i2v_distilled": { "shift": 7.0, }, + "480p_i2v_step_distilled": { + "shift": 7.0, + }, } GUIDANCE_CONFIGS = { @@ -117,6 +125,9 @@ "720p_i2v_distilled": { "guidance_scale": 1.0, }, + "480p_i2v_step_distilled": { + "guidance_scale": 1.0, + }, } @@ -126,7 +137,7 @@ def swap_scale_shift(weight): return new_weight -def convert_hyvideo15_transformer_to_diffusers(original_state_dict): +def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None): """ Convert HunyuanVideo 1.5 original checkpoint to Diffusers format. """ @@ -142,6 +153,20 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict): ) converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias") + if config.use_meanflow: + converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop( + "time_r_in.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop( + "time_r_in.mlp.0.bias" + ) + converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop( + "time_r_in.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop( + "time_r_in.mlp.2.bias" + ) + # 2. 
context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = ( original_state_dict.pop("txt_in.t_embedder.mlp.0.weight") @@ -627,7 +652,7 @@ def convert_transformer(args): config = TRANSFORMER_CONFIGS[args.transformer_type] with init_empty_weights(): transformer = HunyuanVideo15Transformer3DModel(**config) - state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict) + state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=transformer.config) transformer.load_state_dict(state_dict, strict=True, assign=True) return transformer diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py new file mode 100644 index 000000000000..88a9c8700a38 --- /dev/null +++ b/scripts/convert_ltx2_to_diffusers.py @@ -0,0 +1,902 @@ +import argparse +import os +from contextlib import nullcontext +from typing import Any, Dict, Optional, Tuple + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2LatentUpsamplePipeline, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder +from diffusers.utils.import_utils import is_accelerate_available + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + + +LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { + # Input Patchify Projections + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + # Modulation Parameters + # Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are + # substrings of the other modulation parameters below + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + # Transformer Blocks + # Per-Block Cross Attention Modulatin Parameters + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + 
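+# Illustration with a hypothetical checkpoint key (for orientation only): the mapping above is applied as plain
+# substring replacement in `convert_ltx2_video_vae` below, so an original encoder key such as
+#     "encoder.down_blocks.1.res_blocks.0.conv1.weight"
+# would become the diffusers key
+#     "encoder.down_blocks.0.downsamplers.0.resnets.0.conv1.weight".
+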
+LTX_2_0_AUDIO_VAE_RENAME_DICT = { + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +LTX_2_0_VOCODER_RENAME_DICT = { + "ups": "upsamplers", + "resblocks": "resnets", + "conv_pre": "conv_in", + "conv_post": "conv_out", +} + +LTX_2_0_TEXT_ENCODER_RENAME_DICT = { + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + +def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None: + state_dict.pop(key) + + +def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None: + if key.startswith("per_channel_statistics"): + new_key = ".".join(["decoder", key]) + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + +LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "video_embeddings_connector": remove_keys_inplace, + "audio_embeddings_connector": remove_keys_inplace, + "adaln_single": convert_ltx2_transformer_adaln_single, +} + +LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = { + "connectors.": "", + "video_embeddings_connector": "video_connector", + "audio_embeddings_connector": "audio_connector", + "transformer_1d_blocks": "transformer_blocks", + "text_embedding_projection.aggregate_embed": "text_proj_in", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, +} + +LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {} + +LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {} + + +def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: + connector_prefixes = ( + "video_embeddings_connector", + "audio_embeddings_connector", + "transformer_1d_blocks", + "text_embedding_projection.aggregate_embed", + "connectors.", + "video_connector", + "audio_connector", + "text_proj_in", + ) + + transformer_state_dict, connector_state_dict = {}, {} + for key, value in state_dict.items(): + if key.startswith(connector_prefixes): + connector_state_dict[key] = value + else: + transformer_state_dict[key] = value + + return transformer_state_dict, connector_state_dict + + +def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + # Produces a transformer of the same size as used in test_models_transformer_ltx2.py + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "vae_scale_factors": (8, 32, 32), + 
"pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 2, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 16, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1, + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 32, + "attention_head_dim": 128, + "cross_attention_dim": 4096, + "vae_scale_factors": (8, 32, 32), + "pos_embed_max_pos": 20, + "base_height": 2048, + "base_width": 2048, + "audio_in_channels": 128, + "audio_out_channels": 128, + "audio_patch_size": 1, + "audio_patch_size_t": 1, + "audio_num_attention_heads": 32, + "audio_attention_head_dim": 64, + "audio_cross_attention_dim": 2048, + "audio_scale_factor": 4, + "audio_pos_embed_max_pos": 20, + "audio_sampling_rate": 16000, + "audio_hop_length": 160, + "num_layers": 48, + "activation_fn": "gelu-approximate", + "qk_norm": "rms_norm_across_heads", + "norm_elementwise_affine": False, + "norm_eps": 1e-6, + "caption_channels": 3840, + "attention_bias": True, + "attention_out_bias": True, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_offset": 1, + "timestep_scale_multiplier": 1000, + "cross_attn_timestep_scale_multiplier": 1000, + "rope_type": "split", + }, + } + rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT + special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "caption_channels": 16, + "text_proj_in_factor": 3, + "video_connector_num_attention_heads": 4, + "video_connector_attention_head_dim": 8, + "video_connector_num_layers": 1, + "video_connector_num_learnable_registers": None, + "audio_connector_num_attention_heads": 4, + "audio_connector_attention_head_dim": 8, + "audio_connector_num_layers": 1, + "audio_connector_num_learnable_registers": None, + "connector_rope_base_seq_len": 32, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_temporal_positioning": False, + }, + } + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 49, + "video_connector_num_attention_heads": 30, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 2, + "video_connector_num_learnable_registers": 128, + "audio_connector_num_attention_heads": 30, + "audio_connector_attention_head_dim": 128, + "audio_connector_num_layers": 2, + "audio_connector_num_learnable_registers": 128, + 
"connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + "rope_type": "split", + }, + } + + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = {} + + return config, rename_dict, special_keys_remap + + +def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version) + diffusers_config = config["diffusers_config"] + + transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict) + + with init_empty_weights(): + transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(transformer_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(transformer_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(transformer_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, transformer_state_dict) + + transformer.load_state_dict(transformer_state_dict, strict=True, assign=True) + return transformer + + +def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors: + config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version) + diffusers_config = config["diffusers_config"] + + _, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict) + if len(connector_state_dict) == 0: + raise ValueError("No connector weights found in the provided state dict.") + + with init_empty_weights(): + connectors = LTX2TextConnectors.from_config(diffusers_config) + + for key in list(connector_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(connector_state_dict, key, new_key) + + for key in list(connector_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, connector_state_dict) + + connectors.load_state_dict(connector_state_dict, strict=True, assign=True) + return connectors + + +def get_ltx2_video_vae_config( + version: str, timestep_conditioning: bool = False +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + 
"timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (256, 512, 1024, 2048), + "down_block_types": ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 6, 6, 2, 2), + "decoder_layers_per_block": (5, 5, 5, 5), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": timestep_conditioning, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", + "spatial_compression_ratio": 32, + "temporal_compression_ratio": 8, + }, + } + rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT + special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_video_vae( + original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool +) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Video.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "base_channels": 128, + "output_channels": 2, + "ch_mult": (1, 2, 4), + "num_res_blocks": 2, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 256, + "latent_channels": 8, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 64, + "double_z": True, + }, + } + rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT + special_keys_remap = 
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vae = AutoencoderKLLTX2Audio.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True, assign=True) + return vae + + +def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "in_channels": 128, + "hidden_channels": 1024, + "out_channels": 2, + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_factors": [6, 5, 2, 2, 2], + "resnet_kernel_sizes": [3, 7, 11], + "resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 24000, + }, + } + rename_dict = LTX_2_0_VOCODER_RENAME_DICT + special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP + return config, rename_dict, special_keys_remap + + +def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: + config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version) + diffusers_config = config["diffusers_config"] + + with init_empty_weights(): + vocoder = LTX2Vocoder.from_config(diffusers_config) + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vocoder.load_state_dict(original_state_dict, strict=True, assign=True) + return vocoder + + +def get_ltx2_spatial_latent_upsampler_config(version: str): + if version == "2.0": + config = { + "in_channels": 128, + "mid_channels": 1024, + "num_blocks_per_stage": 4, + "dims": 3, + "spatial_upsample": True, + "temporal_upsample": False, + "rational_spatial_scale": 2.0, + } + else: + raise ValueError(f"Unsupported version: {version}") + return config + + +def convert_ltx2_spatial_latent_upsampler( + original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype +): + with init_empty_weights(): + latent_upsampler = LTX2LatentUpsamplerModel(**config) + + latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True) + 
latent_upsampler.to(dtype) + return latent_upsampler + + +def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]: + if repo_id is None and filename is None: + raise ValueError("Please supply at least one of `repo_id` or `filename`") + + if repo_id is not None: + if filename is None: + raise ValueError("If repo_id is specified, filename must also be specified.") + ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename) + else: + ckpt_path = filename + + _, ext = os.path.splitext(ckpt_path) + if ext in [".safetensors", ".sft"]: + state_dict = safetensors.torch.load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + return state_dict + + +def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]: + # Ensure that the key prefix ends with a dot (.) + if not prefix.endswith("."): + prefix = prefix + "." + + model_state_dict = {} + for param_name, param in combined_ckpt.items(): + if param_name.startswith(prefix): + model_state_dict[param_name.replace(prefix, "")] = param + + if prefix == "model.diffusion_model.": + # Some checkpoints store the text connector projection outside the diffusion model prefix. + connector_key = "text_embedding_projection.aggregate_embed.weight" + if connector_key in combined_ckpt and connector_key not in model_state_dict: + model_state_dict[connector_key] = combined_ckpt[connector_key] + + return model_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + + def none_or_str(value: str): + if isinstance(value, str) and value.lower() == "none": + return None + return value + + parser.add_argument( + "--original_state_dict_repo_id", + default="Lightricks/LTX-2", + type=none_or_str, + help="HF Hub repo id with LTX 2.0 checkpoint", + ) + parser.add_argument( + "--checkpoint_path", + default=None, + type=str, + help="Local checkpoint path for LTX 2.0. 
Will be used if `original_state_dict_repo_id` is not specified.",
+    )
+    parser.add_argument(
+        "--version",
+        type=str,
+        default="2.0",
+        choices=["test", "2.0"],
+        help="Version of the LTX 2.0 model",
+    )
+
+    parser.add_argument(
+        "--combined_filename",
+        default="ltx-2-19b-dev.safetensors",
+        type=none_or_str,
+        help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
+    )
+    parser.add_argument("--vae_prefix", default="vae.", type=str)
+    parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
+    parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
+    parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
+
+    parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
+    parser.add_argument(
+        "--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
+    )
+    parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
+    parser.add_argument(
+        "--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
+    )
+    parser.add_argument(
+        "--text_encoder_model_id",
+        default="google/gemma-3-12b-it-qat-q4_0-unquantized",
+        type=none_or_str,
+        help="HF Hub id for the LTX 2.0 base text encoder model",
+    )
+    parser.add_argument(
+        "--tokenizer_id",
+        default="google/gemma-3-12b-it-qat-q4_0-unquantized",
+        type=none_or_str,
+        help="HF Hub id for the LTX 2.0 text tokenizer",
+    )
+    parser.add_argument(
+        "--latent_upsampler_filename",
+        default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
+        type=none_or_str,
+        help="Latent upsampler filename",
+    )
+
+    parser.add_argument(
+        "--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
+    )
+    parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
+    parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
+    parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
+    parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
+    parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
+    parser.add_argument("--text_encoder", action="store_true", help="Whether to convert the text encoder")
+    parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
+    parser.add_argument(
+        "--full_pipeline",
+        action="store_true",
+        help="Whether to save the pipeline. This will attempt to convert all models (e.g.
vae, dit, etc.)", + ) + parser.add_argument( + "--upsample_pipeline", + action="store_true", + help="Whether to save a latent upsampling pipeline", + ) + + parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"]) + + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} + + +def main(args): + vae_dtype = DTYPE_MAPPING[args.vae_dtype] + audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype] + dit_dtype = DTYPE_MAPPING[args.dit_dtype] + vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype] + text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype] + + combined_ckpt = None + load_combined_models = any( + [ + args.vae, + args.audio_vae, + args.dit, + args.vocoder, + args.text_encoder, + args.full_pipeline, + args.upsample_pipeline, + ] + ) + if args.combined_filename is not None and load_combined_models: + combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename) + + if args.vae or args.full_pipeline or args.upsample_pipeline: + if args.vae_filename is not None: + original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename) + elif combined_ckpt is not None: + original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix) + vae = convert_ltx2_video_vae( + original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning + ) + if not args.full_pipeline and not args.upsample_pipeline: + vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae")) + + if args.audio_vae or args.full_pipeline: + if args.audio_vae_filename is not None: + original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename) + elif combined_ckpt is not None: + original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix) + audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version) + if not args.full_pipeline: + audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae")) + + if args.dit or args.full_pipeline: + if args.dit_filename is not None: + original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) + elif combined_ckpt is not None: + original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) + if not args.full_pipeline: + transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + if args.connectors or args.full_pipeline: + if args.dit_filename is not None: + original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) + elif combined_ckpt is not None: + original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + connectors = convert_ltx2_connectors(original_connectors_ckpt, 
version=args.version)
+        if not args.full_pipeline:
+            connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
+
+    if args.vocoder or args.full_pipeline:
+        if args.vocoder_filename is not None:
+            original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
+        elif combined_ckpt is not None:
+            original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
+        vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
+        if not args.full_pipeline:
+            vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
+
+    if args.text_encoder or args.full_pipeline:
+        # text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
+        text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
+        if not args.full_pipeline:
+            text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
+
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
+        if not args.full_pipeline:
+            tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
+
+    if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
+        original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
+            repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
+        )
+        latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
+        latent_upsampler = convert_ltx2_spatial_latent_upsampler(
+            original_latent_upsampler_ckpt,
+            latent_upsampler_config,
+            dtype=vae_dtype,
+        )
+        if not args.full_pipeline and not args.upsample_pipeline:
+            latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
+
+    if args.full_pipeline:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
+
+        pipe = LTX2Pipeline(
+            scheduler=scheduler,
+            vae=vae,
+            audio_vae=audio_vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            connectors=connectors,
+            transformer=transformer,
+            vocoder=vocoder,
+        )
+
+        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+    if args.upsample_pipeline:
+        pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
+
+        # Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
+        pipe.save_pretrained(
+            os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
+        )
+
+
+if __name__ == "__main__":
+    args = get_args()
+    main(args)
diff --git a/setup.py b/setup.py
index 8d346ddfecca..d52d37787fdb 100644
--- a/setup.py
+++ b/setup.py
@@ -74,7 +74,7 @@
     twine upload dist/* -r pypi
 10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. You can use the following
-    Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/sayakpaul/auto-release-notes-diffusers.
+    Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/sayakpaul/auto-release-notes-diffusers.
     It automatically fetches the correct tag and branch but also provides the option to configure them. `tag` should
     be the previous release tag (v0.26.1, for example), and `branch` should be the latest release branch (v0.27.0-release, for example).
It denotes all commits that have happened on branch @@ -274,7 +274,7 @@ def run(self): setup( name="diffusers", - version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="State-of-the-art diffusion in PyTorch and JAX.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eb8e86c4c89d..2024caf37d5c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.36.0.dev0" +__version__ = "0.37.0.dev0" from typing import TYPE_CHECKING @@ -23,6 +23,7 @@ is_torchao_available, is_torchsde_available, is_transformers_available, + is_transformers_version, ) @@ -167,12 +168,16 @@ "FirstBlockCacheConfig", "HookRegistry", "LayerSkipConfig", + "MagCacheConfig", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", + "TaylorSeerCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", + "apply_mag_cache", "apply_pyramid_attention_broadcast", + "apply_taylorseer_cache", ] ) _import_structure["models"].extend( @@ -191,6 +196,8 @@ "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", "AutoencoderKLHunyuanVideo15", + "AutoencoderKLLTX2Audio", + "AutoencoderKLLTX2Video", "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -221,6 +228,7 @@ "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", + "GlmImageTransformer2DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -233,6 +241,8 @@ "Kandinsky3UNet", "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", + "LongCatImageTransformer2DModel", + "LTX2VideoTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", @@ -276,6 +286,7 @@ "WanAnimateTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "ZImageControlNetModel", "ZImageTransformer2DModel", "attention_backend", ] @@ -349,6 +360,7 @@ "KDPM2AncestralDiscreteScheduler", "KDPM2DiscreteScheduler", "LCMScheduler", + "LTXEulerAncestralRFScheduler", "PNDMScheduler", "RePaintScheduler", "SASolverScheduler", @@ -402,6 +414,11 @@ else: _import_structure["modular_pipelines"].extend( [ + "Flux2AutoBlocks", + "Flux2KleinAutoBlocks", + "Flux2KleinBaseAutoBlocks", + "Flux2KleinModularPipeline", + "Flux2ModularPipeline", "FluxAutoBlocks", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", @@ -411,12 +428,16 @@ "QwenImageEditModularPipeline", "QwenImageEditPlusAutoBlocks", "QwenImageEditPlusModularPipeline", + "QwenImageLayeredAutoBlocks", + "QwenImageLayeredModularPipeline", "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "Wan22AutoBlocks", "WanAutoBlocks", "WanModularPipeline", + "ZImageAutoBlocks", + "ZImageModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -441,9 +462,11 @@ "AuraFlowPipeline", "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", + "BriaFiboEditPipeline", "BriaFiboPipeline", "BriaPipeline", "ChromaImg2ImgPipeline", + "ChromaInpaintPipeline", "ChromaPipeline", "ChronoEditPipeline", "CLIPImageProjection", @@ -455,6 +478,7 @@ "CogView4ControlPipeline", "CogView4Pipeline", "ConsisIDPipeline", + "Cosmos2_5_PredictBasePipeline", "Cosmos2TextToImagePipeline", 
"Cosmos2VideoToWorldPipeline", "CosmosTextToWorldPipeline", @@ -463,6 +487,7 @@ "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", + "Flux2KleinPipeline", "Flux2Pipeline", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", @@ -477,6 +502,7 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "GlmImagePipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -526,7 +552,13 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LongCatImageEditPipeline", + "LongCatImagePipeline", + "LTX2ImageToVideoPipeline", + "LTX2LatentUpsamplePipeline", + "LTX2Pipeline", "LTXConditionPipeline", + "LTXI2VLongMultiPromptPipeline", "LTXImageToVideoPipeline", "LTXLatentUpsamplePipeline", "LTXPipeline", @@ -555,6 +587,7 @@ "QwenImageEditPlusPipeline", "QwenImageImg2ImgPipeline", "QwenImageInpaintPipeline", + "QwenImageLayeredPipeline", "QwenImagePipeline", "ReduxImageEncoder", "SanaControlNetPipeline", @@ -660,6 +693,10 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", + "ZImageControlNetInpaintPipeline", + "ZImageControlNetPipeline", + "ZImageImg2ImgPipeline", + "ZImageOmniPipeline", "ZImagePipeline", ] ) @@ -897,12 +934,16 @@ FirstBlockCacheConfig, HookRegistry, LayerSkipConfig, + MagCacheConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, + apply_mag_cache, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) from .models import ( AllegroTransformer3DModel, @@ -919,6 +960,8 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -949,6 +992,7 @@ FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, @@ -961,6 +1005,8 @@ Kandinsky3UNet, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LongCatImageTransformer2DModel, + LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, @@ -1003,6 +1049,7 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, + ZImageControlNetModel, ZImageTransformer2DModel, attention_backend, ) @@ -1068,6 +1115,7 @@ KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LCMScheduler, + LTXEulerAncestralRFScheduler, PNDMScheduler, RePaintScheduler, SASolverScheduler, @@ -1104,6 +1152,11 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .modular_pipelines import ( + Flux2AutoBlocks, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, + Flux2ModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, @@ -1113,12 +1166,16 @@ QwenImageEditModularPipeline, QwenImageEditPlusAutoBlocks, QwenImageEditPlusModularPipeline, + QwenImageLayeredAutoBlocks, + QwenImageLayeredModularPipeline, QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline, + ZImageAutoBlocks, + ZImageModularPipeline, ) from .pipelines import ( AllegroPipeline, @@ -1139,9 +1196,11 @@ AudioLDM2UNet2DConditionModel, AudioLDMPipeline, AuraFlowPipeline, + BriaFiboEditPipeline, BriaFiboPipeline, 
         BriaPipeline,
         ChromaImg2ImgPipeline,
+        ChromaInpaintPipeline,
         ChromaPipeline,
         ChronoEditPipeline,
         CLIPImageProjection,
@@ -1153,6 +1212,7 @@
         CogView4ControlPipeline,
         CogView4Pipeline,
         ConsisIDPipeline,
+        Cosmos2_5_PredictBasePipeline,
         Cosmos2TextToImagePipeline,
         Cosmos2VideoToWorldPipeline,
         CosmosTextToWorldPipeline,
@@ -1161,6 +1221,7 @@
         EasyAnimateControlPipeline,
         EasyAnimateInpaintPipeline,
         EasyAnimatePipeline,
+        Flux2KleinPipeline,
         Flux2Pipeline,
         FluxControlImg2ImgPipeline,
         FluxControlInpaintPipeline,
@@ -1175,6 +1236,7 @@
         FluxKontextPipeline,
         FluxPipeline,
         FluxPriorReduxPipeline,
+        GlmImagePipeline,
         HiDreamImagePipeline,
         HunyuanDiTControlNetPipeline,
         HunyuanDiTPAGPipeline,
@@ -1224,7 +1286,13 @@
         LDMTextToImagePipeline,
         LEditsPPPipelineStableDiffusion,
         LEditsPPPipelineStableDiffusionXL,
+        LongCatImageEditPipeline,
+        LongCatImagePipeline,
+        LTX2ImageToVideoPipeline,
+        LTX2LatentUpsamplePipeline,
+        LTX2Pipeline,
         LTXConditionPipeline,
+        LTXI2VLongMultiPromptPipeline,
         LTXImageToVideoPipeline,
         LTXLatentUpsamplePipeline,
         LTXPipeline,
@@ -1253,6 +1321,7 @@
         QwenImageEditPlusPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
+        QwenImageLayeredPipeline,
         QwenImagePipeline,
         ReduxImageEncoder,
         SanaControlNetPipeline,
@@ -1356,6 +1425,10 @@
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
+        ZImageControlNetInpaintPipeline,
+        ZImageControlNetPipeline,
+        ZImageImg2ImgPipeline,
+        ZImageOmniPipeline,
         ZImagePipeline,
     )
diff --git a/src/diffusers/commands/fp16_safetensors.py b/src/diffusers/commands/fp16_safetensors.py
index 41739261e553..382d6c39bd19 100644
--- a/src/diffusers/commands/fp16_safetensors.py
+++ b/src/diffusers/commands/fp16_safetensors.py
@@ -35,8 +35,8 @@ def conversion_command_factory(args: Namespace):
     if args.use_auth_token:
         warnings.warn(
-            "The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
-            " handled automatically if user is logged in."
+            "The `--use_auth_token` flag is deprecated and will be removed in a future version. "
+            "Authentication is now handled automatically if the user is logged in."
         )
     return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
@@ -92,8 +92,8 @@ def run(self):
         pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
         self.logger.info(f"Pipeline class imported: {pipeline_class_name}.")
-        # Load the appropriate pipeline. We could have use `DiffusionPipeline`
-        # here, but just to avoid any rough edge cases.
+        # Load the appropriate pipeline. We could have used `DiffusionPipeline`
+        # here, but just to avoid potential edge cases.
         pipeline = pipeline_class.from_pretrained(
             self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
         )
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index 4e53c373c4f4..58ad0c211b64 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -25,6 +25,7 @@
     from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
     from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
     from .guider_utils import BaseGuidance
+    from .magnitude_aware_guidance import MagnitudeAwareGuidance
     from .perturbed_attention_guidance import PerturbedAttentionGuidance
     from .skip_layer_guidance import SkipLayerGuidance
     from .smoothed_energy_guidance import SmoothedEnergyGuidance
diff --git a/src/diffusers/guiders/magnitude_aware_guidance.py b/src/diffusers/guiders/magnitude_aware_guidance.py
new file mode 100644
index 000000000000..b81cf0d3a1f9
--- /dev/null
+++ b/src/diffusers/guiders/magnitude_aware_guidance.py
@@ -0,0 +1,159 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class MagnitudeAwareGuidance(BaseGuidance):
+    """
+    Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
+
+    Args:
+        guidance_scale (`float`, defaults to `10.0`):
+            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+            prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+            deterioration of image quality.
+        alpha (`float`, defaults to `8.0`):
+            The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
+            the guidance scale when the magnitude of the guidance update is large.
+        guidance_rescale (`float`, defaults to `0.0`):
+            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+            Flawed](https://huggingface.co/papers/2305.08891).
+        use_original_formulation (`bool`, defaults to `False`):
+            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+            we use the diffusers-native implementation that has been in the codebase for a long time. See
+            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+        start (`float`, defaults to `0.0`):
+            The fraction of the total number of denoising steps after which guidance starts.
+        stop (`float`, defaults to `1.0`):
+            The fraction of the total number of denoising steps after which guidance stops.
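+
+    With the default (diffusers-native) formulation, as implemented in `mambo_guidance` below, the effective guidance
+    weight applied at each step is `1 + (guidance_scale - 1) * exp(-alpha * r)`, where
+    `r = ||pred_cond - pred_uncond|| / ||pred_uncond||`. Steps where the conditional and unconditional predictions
+    disagree strongly are therefore damped toward an effective weight of 1, i.e. little extra guidance.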
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + @register_to_config + def __init__( + self, + guidance_scale: float = 10.0, + alpha: float = 8.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + start: float = 0.0, + stop: float = 1.0, + enabled: bool = True, + ): + super().__init__(start, stop, enabled) + + self.guidance_scale = guidance_scale + self.alpha = alpha + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch(data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: + pred = None + + if not self._is_mambo_g_enabled(): + pred = pred_cond + else: + pred = mambo_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.alpha, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + + @property + def is_conditional(self) -> bool: + return self._count_prepared == 1 + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_mambo_g_enabled(): + num_conditions += 1 + return num_conditions + + def _is_mambo_g_enabled(self) -> bool: + if not self._enabled: + return False + + is_within_range = True + if self._num_inference_steps is not None: + skip_start_step = int(self._start * self._num_inference_steps) + skip_stop_step = int(self._stop * self._num_inference_steps) + is_within_range = skip_start_step <= self._step < skip_stop_step + + is_close = False + if self.use_original_formulation: + is_close = math.isclose(self.guidance_scale, 0.0) + else: + is_close = math.isclose(self.guidance_scale, 1.0) + + return is_within_range and not is_close + + +def mambo_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + alpha: float = 8.0, + use_original_formulation: bool = False, +): + dim = list(range(1, len(pred_cond.shape))) + diff = pred_cond - pred_uncond + ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True) + guidance_scale_final = ( + guidance_scale * torch.exp(-alpha * ratio) + if use_original_formulation + else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio) + ) + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale_final * diff + + return pred diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea9966..23c8bc92b2f1 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -23,5 +23,7 @@ from .hooks import 
HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook + from .mag_cache import MagCacheConfig, apply_mag_cache from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig + from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index ca7934e5c313..f5dd1f8c7c4d 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -23,7 +23,13 @@ _ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin) _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) -_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ( + "blocks", + "transformer_blocks", + "single_transformer_blocks", + "layers", + "visual_transformer_blocks", +) _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index da7313cb4737..1cbc3a35d5b9 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -26,6 +26,7 @@ class AttentionProcessorMetadata: class TransformerBlockMetadata: return_hidden_states_index: int = None return_encoder_hidden_states_index: int = None + hidden_states_argument_name: str = "hidden_states" _cls: Type = None _cached_parameter_indices: Dict[str, int] = None @@ -169,7 +170,7 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): - from ..models.attention import BasicTransformerBlock + from ..models.attention import BasicTransformerBlock, JointTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock @@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata(): HunyuanImageSingleTransformerBlock, HunyuanImageTransformerBlock, ) + from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock from ..models.transformers.transformer_mochi import MochiTransformerBlock from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock @@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata(): ), ) + TransformerBlockRegistry.register( + model_class=JointTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock) + TransformerBlockRegistry.register( + model_class=Kandinsky5TransformerDecoderBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + hidden_states_argument_name="visual_embed", + ), + ) + # fmt: off def _skip_attention___ret___hidden_states(self, *args, **kwargs): diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 6491d17b4f46..6f245d0befab 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -11,12 +11,14 
@@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import copy +import functools import inspect from dataclasses import dataclass -from typing import Dict, List, Type, Union +from typing import Dict, List, Tuple, Type, Union import torch +import torch.distributed as dist if torch.distributed.is_available(): @@ -27,9 +29,10 @@ ContextParallelInput, ContextParallelModelPlan, ContextParallelOutput, + gather_size_by_comm, ) from ..utils import get_logger -from ..utils.torch_utils import unwrap_module +from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module from .hooks import HookRegistry, ModelHook @@ -208,6 +211,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> ) return x else: + if self.parallel_config.ulysses_anything: + return PartitionAnythingSharder.shard_anything( + x, cp_input.split_dim, self.parallel_config._flattened_mesh + ) return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) @@ -233,7 +240,14 @@ def post_forward(self, module, output): for i, cpm in enumerate(self.metadata): if cpm is None: continue - output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + if self.parallel_config.ulysses_anything: + output[i] = PartitionAnythingSharder.unshard_anything( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) + else: + output[i] = EquipartitionSharder.unshard( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) return output[0] if is_tensor else tuple(output) @@ -274,6 +288,73 @@ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_ return tensor +class AllGatherAnythingFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh): + ctx.dim = dim + ctx.group = group + ctx.world_size = dist.get_world_size(group) + ctx.rank = dist.get_rank(group) + gathered_tensor = _all_gather_anything(tensor, dim, group) + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! + grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim) + return grad_splits[ctx.rank], None, None + + +class PartitionAnythingSharder: + @classmethod + def shard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + assert tensor.size()[dim] >= mesh.size(), ( + f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." + ) + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! 
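+        # For example, torch.chunk(torch.arange(9), 4) returns only 3 chunks (each of size 3), whereas
+        # torch.tensor_split(torch.arange(9), 4) always returns 4 splits, of sizes (3, 2, 2, 2) here, so
+        # indexing the result with the local rank below is always valid.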
+ return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] + + @classmethod + def unshard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + tensor = tensor.contiguous() + tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) + return tensor + + +@functools.lru_cache(maxsize=64) +def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: + gather_shapes = [] + for i in range(world_size): + rank_shape = list(copy.deepcopy(shape)) + rank_shape[dim] = gather_dims[i] + gather_shapes.append(rank_shape) + return gather_shapes + + +@maybe_allow_in_graph +def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor: + world_size = dist.get_world_size(group=group) + + tensor = tensor.contiguous() + shape = tensor.shape + rank_dim = shape[dim] + gather_dims = gather_size_by_comm(rank_dim, group) + + gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size) + + gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes] + + dist.all_gather(gathered_tensors, tensor, group=group) + gathered_tensor = torch.cat(gathered_tensors, dim=dim) + return gathered_tensor + + def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]: if name.count("*") > 1: raise ValueError("Wildcard '*' can only be used once in the name") diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 11b8dfd15222..47f1f4199615 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -15,7 +15,7 @@ import hashlib import os from contextlib import contextmanager, nullcontext -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union @@ -59,6 +59,9 @@ class GroupOffloadingConfig: num_blocks_per_group: Optional[int] = None offload_to_disk_path: Optional[str] = None stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + block_modules: Optional[List[str]] = None + exclude_kwargs: Optional[List[str]] = None + module_prefix: Optional[str] = "" class ModuleGroup: @@ -77,7 +80,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - group_id: Optional[int] = None, + group_id: Optional[Union[int, str]] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -322,7 +325,21 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): self.group.stream.synchronize() args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) - kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + + # Some Autoencoder models use a feature cache that is passed through submodules + # and modified in place. The `send_to_device` call returns a copy of this feature cache object + # which breaks the inplace updates. 
Use `exclude_kwargs` to mark these cache features + exclude_kwargs = self.config.exclude_kwargs or [] + if exclude_kwargs: + moved_kwargs = send_to_device( + {k: v for k, v in kwargs.items() if k not in exclude_kwargs}, + self.group.onload_device, + non_blocking=self.group.non_blocking, + ) + kwargs.update(moved_kwargs) + else: + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + return args, kwargs def post_forward(self, module: torch.nn.Module, output): @@ -455,6 +472,8 @@ def apply_group_offloading( record_stream: bool = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[List[str]] = None, + exclude_kwargs: Optional[List[str]] = None, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -512,6 +531,13 @@ def apply_group_offloading( If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. + block_modules (`List[str]`, *optional*): + List of module names that should be treated as blocks for offloading. If provided, only these modules will + be considered for block-level offloading. If not provided, the default block detection logic will be used. + exclude_kwargs (`List[str]`, *optional*): + List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like + caching lists that need to maintain their object identity across forward passes. If not provided, will be + inferred from the module's `_skip_keys` attribute if it exists. Example: ```python @@ -553,6 +579,12 @@ def apply_group_offloading( _raise_error_if_accelerate_model_or_sequential_hook_present(module) + if block_modules is None: + block_modules = getattr(module, "_group_offload_block_modules", None) + + if exclude_kwargs is None: + exclude_kwargs = getattr(module, "_skip_keys", None) + config = GroupOffloadingConfig( onload_device=onload_device, offload_device=offload_device, @@ -563,6 +595,8 @@ def apply_group_offloading( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + exclude_kwargs=exclude_kwargs, ) _apply_group_offloading(module, config) @@ -578,46 +612,66 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" - This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to - the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - """ + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly + defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is + done at the top-level blocks and modules specified in block_modules. + When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified + module, recursively apply block offloading to it. + """ if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( f"Using streams is only supported for num_blocks_per_group=1. 
Got {config.num_blocks_per_group=}. Setting it to 1." ) config.num_blocks_per_group = 1 - # Create module groups for ModuleList and Sequential blocks + block_modules = set(config.block_modules) if config.block_modules is not None else set() + + # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules modules_with_group_offloading = set() unmatched_modules = [] matched_module_groups = [] + for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - unmatched_modules.append((name, submodule)) + # Check if this is an explicitly defined block module + if name in block_modules: + # Track submodule using a prefix to avoid filename collisions during disk offload. + # Without this, submodules sharing the same model class would be assigned identical + # filenames (derived from the class name). + prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}." + submodule_config = replace(config, module_prefix=prefix) + + _apply_group_offloading_block_level(submodule, submodule_config) modules_with_group_offloading.add(name) - continue - for i in range(0, len(submodule), config.num_blocks_per_group): - current_modules = submodule[i : i + config.num_blocks_per_group] - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - group = ModuleGroup( - modules=current_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=current_modules[-1], - onload_leader=current_modules[0], - non_blocking=config.non_blocking, - stream=config.stream, - record_stream=config.record_stream, - low_cpu_mem_usage=config.low_cpu_mem_usage, - onload_self=True, - group_id=group_id, - ) - matched_module_groups.append(group) - for j in range(i, i + len(current_modules)): - modules_with_group_offloading.add(f"{name}.{j}") + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Handle ModuleList and Sequential blocks as before + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = list(submodule[i : i + config.num_blocks_per_group]) + if len(current_modules) == 0: + continue + + group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}" + group = ModuleGroup( + modules=current_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=group_id, + ) + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + # This is an unmatched module + unmatched_modules.append((name, submodule)) # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): @@ -632,28 +686,29 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf parameters = [param for _, param in parameters] buffers = [buffer for _, buffer in buffers] - # Create a group for the unmatched submodules of the top-level module so that they are on the correct - # device when the forward pass is called. 
+ # Create a group for the remaining unmatched submodules of the top-level + # module so that they are on the correct device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - unmatched_group = ModuleGroup( - modules=unmatched_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=module, - onload_leader=module, - parameters=parameters, - buffers=buffers, - non_blocking=False, - stream=None, - record_stream=False, - onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", - ) - if config.stream is None: - _apply_group_offloading_hook(module, unmatched_group, config=config) - else: - _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + unmatched_group = ModuleGroup( + modules=unmatched_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=module, + onload_leader=module, + parameters=parameters, + buffers=buffers, + non_blocking=False, + stream=None, + record_stream=False, + onload_self=True, + group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group", + ) + if config.stream is None: + _apply_group_offloading_hook(module, unmatched_group, config=config) + else: + _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py new file mode 100644 index 000000000000..d28cd2d793b6 --- /dev/null +++ b/src/diffusers/hooks/mag_cache.py @@ -0,0 +1,468 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook" +_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" + +# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience. +# Users must explicitly pass these to the config if using Flux. 
+# Reference: https://github.com/Zehong-Ma/MagCache +FLUX_MAG_RATIOS = torch.tensor( + [1.0] + + [ + 1.21094, + 1.11719, + 1.07812, + 1.0625, + 1.03906, + 1.03125, + 1.03906, + 1.02344, + 1.03125, + 1.02344, + 0.98047, + 1.01562, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.0, + 0.99609, + 0.99609, + 0.98047, + 0.98828, + 0.96484, + 0.95703, + 0.93359, + 0.89062, + ] +) + + +def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor: + """ + Interpolate the source array to the target length using nearest neighbor interpolation. + """ + src_length = len(src_array) + if target_length == 1: + return src_array[-1:] + + scale = (src_length - 1) / (target_length - 1) + grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32) + mapped_indices = torch.round(grid * scale).long() + return src_array[mapped_indices] + + +@dataclass +class MagCacheConfig: + r""" + Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache). + + Args: + threshold (`float`, defaults to `0.06`): + The threshold for the accumulated error. If the accumulated error is below this threshold, the block + computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade + quality. + max_skip_steps (`int`, defaults to `3`): + The maximum number of consecutive steps that can be skipped (K in the paper). + retention_ratio (`float`, defaults to `0.2`): + The fraction of initial steps during which skipping is disabled to ensure stability. For example, if + `num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped. + num_inference_steps (`int`, defaults to `28`): + The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly. + mag_ratios (`torch.Tensor`, *optional*): + The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must + set `calibrate=True` to calculate them for your specific model. For Flux models, you can use + `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`. + calibrate (`bool`, defaults to `False`): + If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the + magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new + models or schedulers. + """ + + threshold: float = 0.06 + max_skip_steps: int = 3 + retention_ratio: float = 0.2 + num_inference_steps: int = 28 + mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None + calibrate: bool = False + + def __post_init__(self): + # User MUST provide ratios OR enable calibration. + if self.mag_ratios is None and not self.calibrate: + raise ValueError( + " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n" + "To get them for your model:\n" + "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n" + "2. Run inference on your model once.\n" + "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n" + "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`." 
+ ) + + if not self.calibrate and self.mag_ratios is not None: + if not torch.is_tensor(self.mag_ratios): + self.mag_ratios = torch.tensor(self.mag_ratios) + + if len(self.mag_ratios) != self.num_inference_steps: + logger.debug( + f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" + ) + self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps) + + +class MagCacheState(BaseState): + def __init__(self) -> None: + super().__init__() + # Cache for the residual (output - input) from the *previous* timestep + self.previous_residual: torch.Tensor = None + + # State inputs/outputs for the current forward pass + self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + # MagCache accumulators + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + + # Current step counter (timestep index) + self.step_index: int = 0 + + # Calibration storage + self.calibration_ratios: List[float] = [] + + def reset(self): + self.previous_residual = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + self.calibration_ratios = [] + + +class MagCacheHeadHook(ModelHook): + _is_stateful = True + + def __init__(self, state_manager: StateManager, config: MagCacheConfig): + self.state_manager = state_manager + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + + arg_name = self._metadata.hidden_states_argument_name + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) + + state: MagCacheState = self.state_manager.get_state() + state.head_block_input = hidden_states + + should_compute = True + + if self.config.calibrate: + # Never skip during calibration + should_compute = True + else: + # MagCache Logic + current_step = state.step_index + if current_step >= len(self.config.mag_ratios): + current_scale = 1.0 + else: + current_scale = self.config.mag_ratios[current_step] + + retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) + + if current_step >= retention_step: + state.accumulated_ratio *= current_scale + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) + + if ( + state.previous_residual is not None + and state.accumulated_err <= self.config.threshold + and state.accumulated_steps <= self.config.max_skip_steps + ): + should_compute = False + else: + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + + state.should_compute = should_compute + + if not should_compute: + logger.debug(f"MagCache: Skipping step {state.step_index}") + # Apply MagCache: Output = Input + Previous Residual + + output = hidden_states + res = state.previous_residual + + if res.device != output.device: + res = res.to(output.device) + + # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only) + if res.shape == output.shape: + output = output + res + elif ( + output.ndim == 3 + and res.ndim == 3 + and output.shape[0] == res.shape[0] + and output.shape[2] == 
res.shape[2] + ): + # Assuming concatenation where image part is at the end (standard in Flux/SD3) + diff = output.shape[1] - res.shape[1] + if diff > 0: + output = output.clone() + output[:, diff:, :] = output[:, diff:, :] + res + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + + if self._metadata.return_encoder_hidden_states_index is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = output + ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + return tuple(ret_list) + else: + return output + + else: + # Compute original forward + output = self.fn_ref.original_forward(*args, **kwargs) + return output + + def reset_state(self, module): + self.state_manager.reset() + return module + + +class MagCacheBlockHook(ModelHook): + def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None): + super().__init__() + self.state_manager = state_manager + self.is_tail = is_tail + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + @torch.compiler.disable + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + state: MagCacheState = self.state_manager.get_state() + + if not state.should_compute: + arg_name = self._metadata.hidden_states_argument_name + hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs) + + if self.is_tail: + # Still need to advance step index even if we skip + self._advance_step(state) + + if self._metadata.return_encoder_hidden_states_index is not None: + encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = hidden_states + ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + return tuple(ret_list) + + return hidden_states + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_tail: + # Calculate residual for next steps + if isinstance(output, tuple): + out_hidden = output[self._metadata.return_hidden_states_index] + else: + out_hidden = output + + in_hidden = state.head_block_input + + if in_hidden is None: + return output + + # Determine residual + if out_hidden.shape == in_hidden.shape: + residual = out_hidden - in_hidden + elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: + diff = in_hidden.shape[1] - out_hidden.shape[1] + if diff == 0: + residual = out_hidden - in_hidden + else: + residual = out_hidden - in_hidden # Fallback to matching 
tail + else: + # Fallback for completely mismatched shapes + residual = out_hidden + + if self.config.calibrate: + self._perform_calibration_step(state, residual) + + state.previous_residual = residual + self._advance_step(state) + + return output + + def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor): + if state.previous_residual is None: + # First step has no previous residual to compare against. + # log 1.0 as a neutral starting point. + ratio = 1.0 + else: + # MagCache Calibration Formula: mean(norm(curr) / norm(prev)) + # norm(dim=-1) gives magnitude of each token vector + curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) + prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1) + + # Avoid division by zero + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + + state.calibration_ratios.append(ratio) + + def _advance_step(self, state: MagCacheState): + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + # End of inference loop + if self.config.calibrate: + print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):") + print(f"{state.calibration_ratios}\n") + logger.info(f"MagCache Calibration Results: {state.calibration_ratios}") + + # Reset state + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + state.calibration_ratios = [] + + +def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: + """ + Applies MagCache to a given module (typically a Transformer). + + Args: + module (`torch.nn.Module`): + The module to apply MagCache to. + config (`MagCacheConfig`): + The configuration for MagCache. + """ + # Initialize registry on the root module so the Pipeline can set context. + HookRegistry.check_if_exists_or_initialize(module) + + state_manager = StateManager(MagCacheState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + if not remaining_blocks: + logger.warning("MagCache: No transformer blocks found to apply hooks.") + return + + # Handle single-block models + if len(remaining_blocks) == 1: + name, block = remaining_blocks[0] + logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") + _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True) + _apply_mag_cache_head_hook(block, state_manager, config) + return + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.info(f"MagCache: Applying Head Hook to {head_block_name}") + _apply_mag_cache_head_hook(head_block, state_manager, config) + + for name, block in remaining_blocks: + _apply_mag_cache_block_hook(block, state_manager, config) + + logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}") + _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True) + + +def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application (e.g. 
switching modes) + if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) + + hook = MagCacheHeadHook(state_manager, config) + registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) + + +def _apply_mag_cache_block_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + is_tail: bool = False, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application + if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK) + + hook = MagCacheBlockHook(state_manager, is_tail, config) + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py new file mode 100644 index 000000000000..7cad9f4fa161 --- /dev/null +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -0,0 +1,346 @@ +import math +import re +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from ..utils import logging +from .hooks import HookRegistry, ModelHook, StateManager + + +logger = logging.get_logger(__name__) +_TAYLORSEER_CACHE_HOOK = "taylorseer_cache" +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( + "^blocks.*attn", + "^transformer_blocks.*attn", + "^single_transformer_blocks.*attn", +) +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) +_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS +_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",) +_PROJ_OUT_IDENTIFIERS = ("^proj_out$",) + + +@dataclass +class TaylorSeerCacheConfig: + """ + Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923 + + Attributes: + cache_interval (`int`, defaults to `5`): + The interval between full computation steps. After a full computation, the cached (predicted) outputs are + reused for this many subsequent denoising steps before refreshing with a new full forward pass. + + disable_cache_before_step (`int`, defaults to `3`): + The denoising step index before which caching is disabled, meaning full computation is performed for the + initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During + these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this + step. + + disable_cache_after_step (`int`, *optional*, defaults to `None`): + The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run + full computations without predictions or state updates, ensuring accuracy in later stages if needed. + + max_order (`int`, defaults to `1`): + The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide + better approximations but increase computation and memory usage. + + taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`): + Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may + affect stability; higher precision improves accuracy at the cost of more memory. + + skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (using `re.fullmatch`) for module names to place as "skip" in "cache" mode. 
In this mode, + the module computes fully during initial or refresh steps but returns a zero tensor (matching recorded + shape) during prediction steps to skip computation cheaply. + + cache_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where + outputs are approximated and cached for reuse. + + use_lite_mode (`bool`, *optional*, defaults to `False`): + Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for + skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom + `skip_predict_identifiers` or `cache_identifiers`. + + Notes: + - Patterns are matched using `re.fullmatch` on the module name. + - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked. + - If neither is provided, all attention-like modules are hooked by default. + + Example of skip (inactive) and cache (active) usage: + + ```py + def forward(x): + x = self.module1(x) # inactive module: returns a zeros tensor based on the shape recorded during full compute + x = self.module2(x) # active module: caches output here, avoiding recomputation on skipped steps + return x + ``` + """ + + cache_interval: int = 5 + disable_cache_before_step: int = 3 + disable_cache_after_step: Optional[int] = None + max_order: int = 1 + taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16 + skip_predict_identifiers: Optional[List[str]] = None + cache_identifiers: Optional[List[str]] = None + use_lite_mode: bool = False + + def __repr__(self) -> str: + return ( + "TaylorSeerCacheConfig(" + f"cache_interval={self.cache_interval}, " + f"disable_cache_before_step={self.disable_cache_before_step}, " + f"disable_cache_after_step={self.disable_cache_after_step}, " + f"max_order={self.max_order}, " + f"taylor_factors_dtype={self.taylor_factors_dtype}, " + f"skip_predict_identifiers={self.skip_predict_identifiers}, " + f"cache_identifiers={self.cache_identifiers}, " + f"use_lite_mode={self.use_lite_mode})" + ) + + +class TaylorSeerState: + def __init__( + self, + taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16, + max_order: int = 1, + is_inactive: bool = False, + ): + self.taylor_factors_dtype = taylor_factors_dtype + self.max_order = max_order + self.is_inactive = is_inactive + + self.module_dtypes: Tuple[torch.dtype, ...] 
= () + self.last_update_step: Optional[int] = None + self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {} + self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None + self.device: Optional[torch.device] = None + self.current_step: int = -1 + + def reset(self) -> None: + self.current_step = -1 + self.last_update_step = None + self.taylor_factors = {} + self.inactive_shapes = None + self.device = None + + def update( + self, + outputs: Tuple[torch.Tensor, ...], + ) -> None: + self.module_dtypes = tuple(output.dtype for output in outputs) + self.device = outputs[0].device + + if self.is_inactive: + self.inactive_shapes = tuple(output.shape for output in outputs) + else: + for i, features in enumerate(outputs): + new_factors: Dict[int, torch.Tensor] = {0: features} + is_first_update = self.last_update_step is None + if not is_first_update: + delta_step = self.current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for TaylorSeer update.") + + # Recursive divided differences up to max_order + prev_factors = self.taylor_factors.get(i, {}) + for j in range(self.max_order): + prev = prev_factors.get(j) + if prev is None: + break + new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step + self.taylor_factors[i] = { + order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items() + } + + self.last_update_step = self.current_step + + @torch.compiler.disable + def predict(self) -> List[torch.Tensor]: + if self.last_update_step is None: + raise ValueError("Cannot predict without prior initialization/update.") + + step_offset = self.current_step - self.last_update_step + + outputs = [] + if self.is_inactive: + if self.inactive_shapes is None: + raise ValueError("Inactive shapes not set during prediction.") + for i in range(len(self.module_dtypes)): + outputs.append( + torch.zeros( + self.inactive_shapes[i], + dtype=self.module_dtypes[i], + device=self.device, + ) + ) + else: + if not self.taylor_factors: + raise ValueError("Taylor factors empty during prediction.") + num_outputs = len(self.taylor_factors) + num_orders = len(self.taylor_factors[0]) + for i in range(num_outputs): + output_dtype = self.module_dtypes[i] + taylor_factors = self.taylor_factors[i] + output = torch.zeros_like(taylor_factors[0], dtype=output_dtype) + for order in range(num_orders): + coeff = (step_offset**order) / math.factorial(order) + factor = taylor_factors[order] + output = output + factor.to(output_dtype) * coeff + outputs.append(output) + return outputs + + +class TaylorSeerCacheHook(ModelHook): + _is_stateful = True + + def __init__( + self, + cache_interval: int, + disable_cache_before_step: int, + taylor_factors_dtype: torch.dtype, + state_manager: StateManager, + disable_cache_after_step: Optional[int] = None, + ): + super().__init__() + self.cache_interval = cache_interval + self.disable_cache_before_step = disable_cache_before_step + self.disable_cache_after_step = disable_cache_after_step + self.taylor_factors_dtype = taylor_factors_dtype + self.state_manager = state_manager + + def initialize_hook(self, module: torch.nn.Module): + return module + + def reset_state(self, module: torch.nn.Module) -> None: + """ + Reset state between sampling runs. 
+ """ + self.state_manager.reset() + + @torch.compiler.disable + def _measure_should_compute(self) -> Tuple[bool, TaylorSeerState]: + state: TaylorSeerState = self.state_manager.get_state() + state.current_step += 1 + current_step = state.current_step + is_warmup_phase = current_step < self.disable_cache_before_step + is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0 + is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step + should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase + return should_compute, state + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + should_compute, state = self._measure_should_compute() + if should_compute: + outputs = self.fn_ref.original_forward(*args, **kwargs) + wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs + state.update(wrapped_outputs) + return outputs + + outputs_list = state.predict() + return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list) + + +def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]: + """ + Resolve the effective inactive (skip) and active (cache) pattern lists from the config. + """ + + inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None + active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None + + return inactive_patterns or [], active_patterns or [] + + +def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): + """ + Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet). + + This function hooks selected modules in the model to enable caching or skipping based on the provided + configuration, reducing redundant computations in diffusion denoising loops. + + Args: + module (torch.nn.Module): The model subtree to apply the hooks to. + config (TaylorSeerCacheConfig): Configuration for the cache. + + Example: + ```python + >>> import torch + >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig + + >>> pipe = FluxPipeline.from_pretrained( + ... "black-forest-labs/FLUX.1-dev", + ... torch_dtype=torch.bfloat16, + ... ) + >>> pipe.to("cuda") + + >>> config = TaylorSeerCacheConfig( + ... cache_interval=5, + ... max_order=1, + ... disable_cache_before_step=3, + ... taylor_factors_dtype=torch.float32, + ... ) + >>> pipe.transformer.enable_cache(config) + ``` + """ + inactive_patterns, active_patterns = _resolve_patterns(config) + + active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS + + if config.use_lite_mode: + logger.info("Using TaylorSeer Lite variant for cache.") + active_patterns = _PROJ_OUT_IDENTIFIERS + inactive_patterns = _BLOCK_IDENTIFIERS + if config.skip_predict_identifiers or config.cache_identifiers: + logger.warning("Lite mode overrides user patterns.") + + for name, submodule in module.named_modules(): + matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns) + matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns) + if not (matches_inactive or matches_active): + continue + _apply_taylorseer_cache_hook( + module=submodule, + config=config, + is_inactive=matches_inactive, + ) + + +def _apply_taylorseer_cache_hook( + module: nn.Module, + config: TaylorSeerCacheConfig, + is_inactive: bool, +): + """ + Registers the TaylorSeer hook on the specified nn.Module. + + Args: 
+ module: The nn.Module to be hooked. + config: Cache configuration. + is_inactive: Whether this module should operate in "inactive" mode. + """ + state_manager = StateManager( + TaylorSeerState, + init_kwargs={ + "taylor_factors_dtype": config.taylor_factors_dtype, + "max_order": config.max_order, + "is_inactive": is_inactive, + }, + ) + + registry = HookRegistry.check_if_exists_or_initialize(module) + + hook = TaylorSeerCacheHook( + cache_interval=config.cache_interval, + disable_cache_before_step=config.disable_cache_before_step, + taylor_factors_dtype=config.taylor_factors_dtype, + disable_cache_after_step=config.disable_cache_after_step, + state_manager=state_manager, + ) + + registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index ace4e8543a1c..bdd4dbbcd4b5 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -67,6 +67,7 @@ def text_encoder_attn_modules(text_encoder): "SD3LoraLoaderMixin", "AuraFlowLoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", + "LTX2LoraLoaderMixin", "LTXVideoLoraLoaderMixin", "LoraLoaderMixin", "FluxLoraLoaderMixin", @@ -121,6 +122,7 @@ def text_encoder_attn_modules(text_encoder): HunyuanVideoLoraLoaderMixin, KandinskyLoraLoaderMixin, LoraLoaderMixin, + LTX2LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index f3c17cd729b8..9fc1a24eb994 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2140,6 +2140,54 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref return converted_state_dict +def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"): + # Remove the prefix + state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")} + converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()} + + if non_diffusers_prefix == "diffusion_model": + rename_dict = { + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + else: + rename_dict = {"aggregate_embed": "text_proj_in"} + + # Apply renaming + renamed_state_dict = {} + for key, value in converted_state_dict.items(): + new_key = key[:] + for old_pattern, new_pattern in rename_dict.items(): + new_key = new_key.replace(old_pattern, new_pattern) + renamed_state_dict[new_key] = value + + # Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed + final_state_dict = {} + for key, value in renamed_state_dict.items(): + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + final_state_dict[new_key] = value + elif key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + final_state_dict[new_key] = value + else: + final_state_dict[key] = 
value + + # Add transformer prefix + prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors" + final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()} + + return final_state_dict + + def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict): has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict) if has_diffusion_model: @@ -2273,8 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict): prefix = "diffusion_model." original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()} - num_double_layers = 8 - num_single_layers = 48 + num_double_layers = 0 + num_single_layers = 0 + for key in original_state_dict.keys(): + if key.startswith("single_blocks."): + num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1) + elif key.startswith("double_blocks."): + num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1) + lora_keys = ("lora_A", "lora_B") attn_types = ("img_attn", "txt_attn") @@ -2417,6 +2471,17 @@ def convert_key(key: str) -> str: state_dict = {convert_key(k): v for k, v in state_dict.items()} + def normalize_out_key(k: str) -> str: + if ".to_out" in k: + return k + return re.sub( + r"\.out(?=\.(?:lora_down|lora_up)\.weight$|\.alpha$)", + ".to_out.0", + k, + ) + + state_dict = {normalize_out_key(k): v for k, v in state_dict.items()} + has_default = any("default." in k for k in state_dict) if has_default: state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()} diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index bcbe54649f89..24d1fd7b9308 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -48,6 +48,7 @@ _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, + _convert_non_diffusers_ltx2_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers, _convert_non_diffusers_lumina2_lora_to_diffusers, _convert_non_diffusers_qwen_lora_to_diffusers, @@ -74,6 +75,7 @@ TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" TRANSFORMER_NAME = "transformer" +LTX2_CONNECTOR_NAME = "connectors" _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} @@ -212,7 +214,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -639,7 +641,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_unet( state_dict, @@ -1079,7 +1081,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1375,7 +1377,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -1487,7 +1489,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Load LoRA layers into [`FluxTransformer2DModel`], [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). - Specific to [`StableDiffusion3Pipeline`]. + Specific to [`FluxPipeline`]. """ _lora_loadable_modules = ["transformer", "text_encoder"] @@ -1628,30 +1630,7 @@ def load_lora_weights( **kwargs, ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1680,7 +1659,7 @@ def load_lora_weights( ) if not (has_lora_keys or has_norm_keys): - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") transformer_lora_state_dict = { k: state_dict.get(k) @@ -2527,7 +2506,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2724,7 +2703,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -2927,7 +2906,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3034,6 +3013,233 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class LTX2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`]. 
+ """ + + _lora_loadable_modules = ["transformer", "connectors"] + transformer_name = TRANSFORMER_NAME + connectors_name = LTX2_CONNECTOR_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + final_state_dict = state_dict + is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) + has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict) + if is_non_diffusers_format: + final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict) + if has_connector: + connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers( + state_dict, "text_embedding_projection" + ) + final_state_dict.update(connectors_state_dict) + out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + + transformer_peft_state_dict = { + k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.") + } + connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")} + self.load_lora_into_transformer( + transformer_peft_state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + if connectors_peft_state_dict: + self.load_lora_into_transformer( + connectors_peft_state_dict, + transformer=getattr(self, self.connectors_name) + if not hasattr(self, "connectors") + else self.connectors, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + prefix=self.connectors_name, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + prefix: str = "transformer", + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {prefix}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + prefix=prefix, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. 
+ """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class SanaLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`]. @@ -3127,7 +3333,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3330,7 +3536,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3534,7 +3740,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -3651,44 +3857,17 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs, ): r""" - Return state dict for lora weights and the network alphas. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - A string, the *model id* of a pretrained model hosted on the Hub. - - A path to a *directory* containing the model weights. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights. 
- proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository. - weight_name (`str`, *optional*, defaults to None): - Name of the serialized state dict file. - use_safetensors (`bool`, *optional*): - Whether to use safetensors for loading. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata. + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. """ - # Load the main state dict first which has the LoRA layers + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -3731,6 +3910,7 @@ def lora_state_dict( out = (state_dict, metadata) if return_lora_metadata else state_dict return out + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -3739,26 +3919,13 @@ def load_lora_weights( **kwargs, ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. - hotswap (`bool`, *optional*): - Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - kwargs (`dict`, *optional*): - See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) @@ -3773,9 +3940,8 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") - # Load LoRA into transformer self.load_lora_into_transformer( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, @@ -3787,6 +3953,7 @@ def load_lora_weights( ) @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer( cls, state_dict, @@ -3798,23 +3965,9 @@ def load_lora_into_transformer( metadata=None, ): """ - Load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. - transformer (`Kandinsky5Transformer3DModel`): - The transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights. - hotswap (`bool`, *optional*): - See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. """ - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) @@ -3832,6 +3985,7 @@ def load_lora_into_transformer( ) @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -3840,24 +3994,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, - transformer_lora_adapter_metadata=None, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the transformer and text encoders. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process. - save_function (`Callable`): - The function to use to save the state dictionary. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way. - transformer_lora_adapter_metadata: - LoRA adapter metadata associated with the transformer. + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. 
""" lora_layers = {} lora_metadata = {} @@ -3867,7 +4007,7 @@ def save_lora_weights( lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata if not lora_layers: - raise ValueError("You must pass at least one of `transformer_lora_layers`") + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") cls._save_lora_weights( save_directory=save_directory, @@ -3879,6 +4019,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -3888,25 +4029,7 @@ def fuse_lora( **kwargs, ): r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. - - Example: - ```py - from diffusers import Kandinsky5T2VPipeline - - pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") - pipeline.load_lora_weights("path/to/lora.safetensors") - pipeline.fuse_lora(lora_scale=0.7) - ``` + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. """ super().fuse_lora( components=components, @@ -3916,12 +4039,10 @@ def fuse_lora( **kwargs, ) + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" - Reverses the effect of [`pipe.fuse_lora()`]. - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. """ super().unfuse_lora(components=components, **kwargs) @@ -4073,7 +4194,7 @@ def load_lora_weights( ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4350,7 +4471,7 @@ def load_lora_weights( ) is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False) if load_into_transformer_2: @@ -4570,7 +4691,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4773,7 +4894,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. 
Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -4979,7 +5100,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5185,7 +5306,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, @@ -5388,7 +5509,7 @@ def load_lora_weights( is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") self.load_lora_into_transformer( state_dict, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3f8519bbfa32..6cdd66a0f2c4 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -27,6 +27,7 @@ MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, + convert_sai_sd_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -62,9 +63,12 @@ "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, "WanVACETransformer3DModel": lambda model_cls, weights: weights, "ChromaTransformer2DModel": lambda model_cls, weights: weights, + "ChronoEditTransformer3DModel": lambda model_cls, weights: weights, "QwenImageTransformer2DModel": lambda model_cls, weights: weights, "Flux2Transformer2DModel": lambda model_cls, weights: weights, "ZImageTransformer2DModel": lambda model_cls, weights: weights, + "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights, + "LTX2TextConnectors": lambda model_cls, weights: weights, } @@ -232,6 +236,13 @@ def load_lora_adapter( if "lora_A" not in first_key: state_dict = convert_unet_state_dict_to_peft(state_dict) + # Control LoRA from SAI is different from BFL Control LoRA + # https://huggingface.co/stabilityai/control-lora + # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors + is_sai_sd_control_lora = "lora_controlnet" in state_dict + if is_sai_sd_control_lora: + state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) + rank = {} for key, val in state_dict.items(): # Cannot figure out rank from lora layers that don't have at least 2 dimensions. 
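The `convert_sai_sd_control_lora_state_dict_to_peft` hook added above follows the same pattern as the existing `convert_unet_state_dict_to_peft` call: detect a marker key in the checkpoint, then rename its LoRA tensors into PEFT's `lora_A`/`lora_B` convention before the adapter config is built. Below is a minimal sketch of that renaming pattern, assuming a kohya-style `lora_down`/`lora_up` layout; the key names and the `to_peft_naming` helper are illustrative only and are not the converter added in this PR.

import torch
from typing import Dict


def to_peft_naming(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Hypothetical input layout: "<module>.lora_down.weight" / "<module>.lora_up.weight";
    # PEFT's LoRA injection expects "<module>.lora_A.weight" / "<module>.lora_B.weight".
    converted = {}
    for key, value in state_dict.items():
        if key.endswith("lora_down.weight"):
            converted[key.replace("lora_down.weight", "lora_A.weight")] = value
        elif key.endswith("lora_up.weight"):
            converted[key.replace("lora_up.weight", "lora_B.weight")] = value
        else:
            # Marker keys, alphas, or extra control weights would need dedicated handling,
            # which is what the real converter is responsible for.
            converted[key] = value
    return converted


# Toy sanity check on a made-up two-tensor state dict.
toy = {
    "blocks.0.attn.lora_down.weight": torch.zeros(4, 320),
    "blocks.0.attn.lora_up.weight": torch.zeros(320, 4),
}
print(sorted(to_peft_naming(toy)))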
@@ -263,6 +274,14 @@ def load_lora_adapter( adapter_name=adapter_name, ) + # Adjust LoRA config for Control LoRA + if is_sai_sd_control_lora: + lora_config.lora_alpha = lora_config.r + lora_config.alpha_pattern = lora_config.rank_pattern + lora_config.bias = "all" + lora_config.modules_to_save = lora_config.exclude_modules + lora_config.exclude_modules = None + # None: @@ -3876,6 +3975,9 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) update_state_dict(converted_state_dict, key, new_key) + if "norm_final.weight" in converted_state_dict.keys(): + _ = converted_state_dict.pop("norm_final.weight") + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in # special_keys_remap for key in list(converted_state_dict.keys()): @@ -3885,3 +3987,175 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) handler_fn_inplace(key, converted_state_dict) return converted_state_dict + + +def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs): + if config["add_control_noise_refiner"] is None: + return checkpoint + elif config["add_control_noise_refiner"] == "control_noise_refiner": + return checkpoint + elif config["add_control_noise_refiner"] == "control_layers": + converted_state_dict = { + key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.") + } + return converted_state_dict + else: + raise ValueError("Unknown Z-Image Turbo ControlNet type.") + + +def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs): + LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = { + # Transformer prefix + "model.diffusion_model.": "", + # Input Patchify Projections + "patchify_proj": "proj_in", + "audio_patchify_proj": "audio_proj_in", + # Modulation Parameters + # Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are + # substrings of the other modulation parameters below + "av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift", + "av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate", + "av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift", + "av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate", + # Transformer Blocks + # Per-Block Cross Attention Modulation Parameters + "scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table", + "scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table", + # Attention QK Norms + "q_norm": "norm_q", + "k_norm": "norm_k", + } + + def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + def remove_keys_inplace(key: str, state_dict) -> None: + state_dict.pop(key) + + def convert_ltx2_transformer_adaln_single(key: str, state_dict) -> None: + # Skip if not a weight, bias + if ".weight" not in key and ".bias" not in key: + return + + if key.startswith("adaln_single."): + new_key = key.replace("adaln_single.", "time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + if key.startswith("audio_adaln_single."): + new_key = key.replace("audio_adaln_single.", "audio_time_embed.") + param = state_dict.pop(key) + state_dict[new_key] = param + + return + + LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "video_embeddings_connector": remove_keys_inplace, + "audio_embeddings_connector": remove_keys_inplace, + "adaln_single": convert_ltx2_transformer_adaln_single, + } + + converted_state_dict = 
{key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + update_state_dict_inplace(converted_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs): + LTX_2_0_VIDEO_VAE_RENAME_DICT = { + # Video VAE prefix + "vae.": "", + # Encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.1", + "down_blocks.3": "down_blocks.1.downsamplers.0", + "down_blocks.4": "down_blocks.2", + "down_blocks.5": "down_blocks.2.downsamplers.0", + "down_blocks.6": "down_blocks.3", + "down_blocks.7": "down_blocks.3.downsamplers.0", + "down_blocks.8": "mid_block", + # Decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + # Common + # For all 3D ResNets + "res_blocks": "resnets", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + def remove_keys_inplace(key: str, state_dict) -> None: + state_dict.pop(key) + + LTX_2_0_VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_inplace, + "per_channel_statistics.mean-of-stds": remove_keys_inplace, + } + + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(converted_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in LTX_2_0_VIDEO_VAE_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + update_state_dict_inplace(converted_state_dict, key, new_key) + + # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in + # special_keys_remap + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in LTX_2_0_VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs): + LTX_2_0_AUDIO_VAE_RENAME_DICT = { + # Audio VAE prefix + "audio_vae.": "", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None: + state_dict[new_key] = state_dict.pop(old_key) + + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} + + # Handle official code --> diffusers key remapping via the remap dict + for key in list(converted_state_dict.keys()): + new_key = key[:] 
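# Illustrative trace of the replacements below: a checkpoint key such as
# "audio_vae.per_channel_statistics.mean-of-means" first loses its "audio_vae." prefix and
# then maps to "latents_mean" via LTX_2_0_AUDIO_VAE_RENAME_DICT.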
+ for replace_key, rename_key in LTX_2_0_AUDIO_VAE_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + update_state_dict_inplace(converted_state_dict, key, new_key) + + return converted_state_dict diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index 63fc97ed431f..e4346700a7b8 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import Dict, List, Optional, Union import safetensors import torch from huggingface_hub.utils import validate_hf_hub_args +from tokenizers import Tokenizer as TokenizerFast from torch import nn from ..models.modeling_utils import load_state_dict @@ -547,23 +549,39 @@ def unload_textual_inversion( else: last_special_token_id = added_token_id - # Delete from tokenizer - for token_id, token_to_remove in zip(token_ids, tokens): - del tokenizer._added_tokens_decoder[token_id] - del tokenizer._added_tokens_encoder[token_to_remove] - - # Make all token ids sequential in tokenizer - key_id = 1 - for token_id in tokenizer.added_tokens_decoder: - if token_id > last_special_token_id and token_id > last_special_token_id + key_id: - token = tokenizer._added_tokens_decoder[token_id] - tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token + # Fast tokenizers (v5+) + if hasattr(tokenizer, "_tokenizer"): + # Fast tokenizers: serialize, filter tokens, reload + tokenizer_json = json.loads(tokenizer._tokenizer.to_str()) + new_id = last_special_token_id + 1 + filtered = [] + for tok in tokenizer_json.get("added_tokens", []): + if tok.get("content") in set(tokens): + continue + if not tok.get("special", False): + tok["id"] = new_id + new_id += 1 + filtered.append(tok) + tokenizer_json["added_tokens"] = filtered + tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json)) + else: + # Slow tokenizers + for token_id, token_to_remove in zip(token_ids, tokens): del tokenizer._added_tokens_decoder[token_id] - tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id - key_id += 1 - tokenizer._update_trie() - # set correct total vocab size after removing tokens - tokenizer._update_total_vocab_size() + del tokenizer._added_tokens_encoder[token_to_remove] + + key_id = 1 + for token_id in list(tokenizer.added_tokens_decoder.keys()): + if token_id > last_special_token_id and token_id > last_special_token_id + key_id: + token = tokenizer._added_tokens_decoder[token_id] + tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token + del tokenizer._added_tokens_decoder[token_id] + tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id + key_id += 1 + if hasattr(tokenizer, "_update_trie"): + tokenizer._update_trie() + if hasattr(tokenizer, "_update_total_vocab_size"): + tokenizer._update_total_vocab_size() # Delete from text encoder text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..4d1db36a7352 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -41,6 +41,8 @@ _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] 
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] + _import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"] + _import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"] @@ -66,6 +68,7 @@ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"] _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] _import_structure["embeddings"] = ["ImageProjection"] @@ -95,13 +98,16 @@ _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] + _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] + _import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] + _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] @@ -151,6 +157,8 @@ AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, AutoencoderKLHunyuanVideo15, + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -180,6 +188,7 @@ SD3MultiControlNetModel, SparseControlNetModel, UNetControlNetXSModel, + ZImageControlNetModel, ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin @@ -200,6 +209,7 @@ EasyAnimateTransformer3DModel, Flux2Transformer2DModel, FluxTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, @@ -208,6 +218,8 @@ HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, LatteTransformer3DModel, + LongCatImageTransformer2DModel, + LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, LuminaNextDiT2DModel, 
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 2a4eb520c796..db45159adfc9 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union import torch +import torch.distributed as dist from ..utils import get_logger @@ -67,6 +68,9 @@ class ContextParallelConfig: convert_to_fp32: bool = True # TODO: support alltoall rotate_method: Literal["allgather", "alltoall"] = "allgather" + # Whether to enable ulysses anything attention to support + # any sequence lengths and any head numbers. + ulysses_anything: bool = False _rank: int = None _world_size: int = None @@ -90,14 +94,15 @@ def __post_init__(self): ) if self.ring_degree < 1 or self.ulysses_degree < 1: raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if self.ring_degree > 1 and self.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) if self.rotate_method != "allgather": raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." ) + if self.ulysses_anything: + if self.ulysses_degree == 1: + raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.") + if self.ring_degree > 1: + raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.") @property def mesh_shape(self) -> Tuple[int, int]: @@ -261,3 +266,39 @@ def __repr__(self): # # ContextParallelOutput: # specifies how to gather the input tensor in the post-forward hook in the layer it is attached to + + +# Below are utility functions for distributed communication in context parallelism. +def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: + r"""Gather the local size from all ranks. + size: int, local size return: List[int], list of size from all ranks + """ + # NOTE(Serving/CP Safety): + # Do NOT cache this collective result. + # + # In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL) + # may legitimately differ across ranks. If we cache based on the *local* `size`, + # different ranks can have different cache hit/miss patterns across time. + # + # That can lead to a catastrophic distributed hang: + # - some ranks hit cache and *skip* dist.all_gather() + # - other ranks miss cache and *enter* dist.all_gather() + # This mismatched collective participation will stall the process group and + # eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL + # timeouts in Ulysses attention). 
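# Practical upshot for callers: every rank must reach this collective on every call;
# no caching or early returns based on the local `size`.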
+ world_size = dist.get_world_size(group=group) + # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead + comm_backends = str(dist.get_backend(group=group)) + # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") + gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator() + gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather( + gathered_sizes, + torch.tensor([size], device=gather_device, dtype=torch.int64), + group=group, + ) + + gathered_sizes = [s[0].item() for s in gathered_sizes] + # NOTE: DON'T use tolist here due to graph break - Explanation: + # Backend compiler `inductor` failed with aten._local_scalar_dense.default + return gathered_sizes diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ffad94cc7f27..a94c0f5a7d43 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -21,6 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist +import torch.nn.functional as F if torch.distributed.is_available(): @@ -44,6 +46,8 @@ is_xformers_version, ) from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ..utils.torch_utils import maybe_allow_in_graph +from ._modeling_parallel import gather_size_by_comm if TYPE_CHECKING: @@ -235,6 +239,10 @@ def decorator(func): def get_active_backend(cls): return cls._active_backend, cls._backends[cls._active_backend] + @classmethod + def set_active_backend(cls, backend: str): + cls._active_backend = backend + @classmethod def list_backends(cls): return list(cls._backends.keys()) @@ -294,12 +302,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke _maybe_download_kernel_for_backend(backend) old_backend = _AttentionBackendRegistry._active_backend - _AttentionBackendRegistry._active_backend = backend + _AttentionBackendRegistry.set_active_backend(backend) try: yield finally: - _AttentionBackendRegistry._active_backend = old_backend + _AttentionBackendRegistry.set_active_backend(old_backend) def dispatch_attention_fn( @@ -348,6 +356,7 @@ def dispatch_attention_fn( check(**kwargs) kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) @@ -868,6 +877,97 @@ def _cudnn_attention_backward_op( return grad_query, grad_key, grad_value +# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15135 +# forward declaration: +# aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) +def _native_flash_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for native flash attention.") + + tensors_to_save = () + + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + tensors_to_save += (query, key, value) + + out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = ( + torch.ops.aten._scaled_dot_product_flash_attention( + query=query, + key=key, + value=value, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=False, + scale=scale, + ) + ) + + tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset) + if _save_ctx: + ctx.save_for_backward(*tensors_to_save) + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.max_q = max_q + ctx.max_k = max_k + + out = out.transpose(1, 2).contiguous() + if lse is not None: + lse = lse.transpose(1, 2).contiguous() + return (out, lse) if return_lse else out + + +# https://github.com/pytorch/pytorch/blob/e33fa0ece36a93dbc8ff19b0251b8d99f8ae8668/aten/src/ATen/native/native_functions.yaml#L15153 +# backward declaration: +# aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) +def _native_flash_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors + + grad_out = grad_out.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward( + grad_out, + query, + key, + value, + out, + logsumexp=lse, + philox_seed=philox_seed, + philox_offset=philox_offset, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=ctx.max_q, + max_k=ctx.max_k, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + ) + grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value)) + + return grad_query, grad_key, grad_value + + # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 def _flash_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, @@ -1015,6 +1115,51 @@ def _sage_attention_backward_op( raise NotImplementedError("Backward pass is not implemented for Sage attention.") +def _npu_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if return_lse: + raise ValueError("NPU attention backend does not support setting `return_lse=True`.") + + out = npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + return out + + +# Not implemented yet. +def _npu_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.") + + # ===== Context parallel ===== @@ -1041,6 +1186,251 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x +def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: + """ + Perform dimension sharding / reassembly across processes using _all_to_all_single. + + This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or + head dimension flexibly by accepting scatter_idx and gather_idx. + + Args: + x (torch.Tensor): + Input tensor. Expected shapes: + - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim) + - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim) + scatter_idx (int) : + Dimension along which the tensor is partitioned before all-to-all. + gather_idx (int): + Dimension along which the output is reassembled after all-to-all. + group : + Distributed process group for the Ulysses group. + + Returns: + torch.Tensor: Tensor with globally exchanged dimensions. 
+              - For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
+              - For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
+    """
+    group_world_size = torch.distributed.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        # Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
+        # dimension and scatters the head dimension.
+        batch_size, seq_len_local, num_heads, head_dim = x.shape
+        seq_len = seq_len_local * group_world_size
+        num_heads_local = num_heads // group_world_size
+
+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
+        x_temp = (
+            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
+            .transpose(0, 2)
+            .contiguous()
+        )
+
+        if group_world_size > 1:
+            out = _all_to_all_single(x_temp, group=group)
+        else:
+            out = x_temp
+        # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
+        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
+        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # Used after Ulysses sequence parallel in unified SP. Gathers the head dimension and
+        # scatters back the sequence dimension.
+        batch_size, seq_len, num_heads_local, head_dim = x.shape
+        num_heads = num_heads_local * group_world_size
+        seq_len_local = seq_len // group_world_size
+
+        # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+        x_temp = (
+            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+            .permute(1, 3, 2, 0, 4)
+            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+        )
+
+        if group_world_size > 1:
+            output = _all_to_all_single(x_temp, group)
+        else:
+            output = x_temp
+        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
+        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
+
+
+class SeqAllToAllDim(torch.autograd.Function):
+    """
+    all_to_all operation for unified sequence parallelism. Uses `_all_to_all_dim_exchange`; see that function for
+    more info.
+    """
+
+    @staticmethod
+    def forward(ctx, group, input, scatter_id=2, gather_id=1):
+        ctx.group = group
+        ctx.scatter_id = scatter_id
+        ctx.gather_id = gather_id
+        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_input = SeqAllToAllDim.apply(
+            ctx.group,
+            grad_outputs,
+            ctx.gather_id,  # reversed
+            ctx.scatter_id,  # reversed
+        )
+        return (None, grad_input, None, None)
+
+
+# Below are helper functions to handle arbitrary head num and arbitrary sequence length for Ulysses Anything Attention.
+def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
+    r"""Maybe pad the head dimension to be divisible by world_size.
+    x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded
+    tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
+    """
+    world_size = dist.get_world_size(group=group)
+    H_PAD = 0
+    if H % world_size != 0:
+        H_PAD = world_size - (H % world_size)
+        NEW_H_LOCAL = (H + H_PAD) // world_size
+        # e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
+        # NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
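# Why the constraint below matters: padded heads sit at the tail of the head dimension, so
# after the all-to-all exchange they all land on the last rank, and _maybe_unpad_qkv_head only
# strips padding there. H_PAD must therefore stay below the per-rank head count, otherwise
# padding would spill onto other ranks and never be removed.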
+ assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_GLOBAL, H_LOCAL, D) + """ + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + # Only the last rank may have padding + if H_PAD > 0 and rank == world_size - 1: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], + padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD + """ + if H is None: + return x, 0 + + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + H_PAD = 0 + # Only the last rank may need padding + if H % world_size != 0: + # We need to broadcast H_PAD to all ranks to keep consistency + # in unpadding step later for all ranks. + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + if rank == world_size - 1: + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_LOCAL, H_GLOBAL, D) + """ + if H_PAD > 0: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: + # query: (B, S_LOCAL, H_GLOBAL, D) + assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)" + extra_kwargs = {} + extra_kwargs["NUM_QO_HEAD"] = query.shape[2] + extra_kwargs["Q_S_LOCAL"] = query.shape[1] + # Add other kwargs if needed in future + return extra_kwargs + + +@maybe_allow_in_graph +def all_to_all_single_any_qkv_async( + x: torch.Tensor, group: dist.ProcessGroup, **kwargs +) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) + """ + world_size = dist.get_world_size(group=group) + B, S_LOCAL, H, D = x.shape + x, H_PAD = _maybe_pad_qkv_head(x, H, group) + H_LOCAL = (H + H_PAD) // world_size + # (world_size, S_LOCAL, B, H_LOCAL, D) + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + + input_split_sizes = [S_LOCAL] * world_size + # S_LOCAL maybe not equal for all ranks in dynamic shape case, + # since we don't know the actual shape before this timing, thus, + # we have to use all gather to collect the S_LOCAL first. 
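# e.g. with world_size=2 and local sequence lengths 5 (rank 0) and 4 (rank 1), every rank
# computes output_split_sizes == [5, 4], so the receive buffers line up even though S_LOCAL differs.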
+ output_split_sizes = gather_size_by_comm(S_LOCAL, group) + x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D) + x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + # (S_GLOBAL, B, H_LOCAL, D) + # -> (B, S_GLOBAL, H_LOCAL, D) + x = x.permute(1, 0, 2, 3).contiguous() + x = _maybe_unpad_qkv_head(x, H_PAD, group) + return x + + return wait + + +@maybe_allow_in_graph +def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) + """ + # Assume H is provided in kwargs, since we can't infer H from x's shape. + # The padding logic needs H to determine if padding is necessary. + H = kwargs.get("NUM_QO_HEAD", None) + world_size = dist.get_world_size(group=group) + + x, H_PAD = _maybe_pad_o_head(x, H, group) + shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) + (B, S_GLOBAL, H_LOCAL, D) = shape + + # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] + # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] + + # WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer + # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then, + # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] + + # b.tensor_split(4)[0].shape[1]) + + S_LOCAL = kwargs.get("Q_S_LOCAL") + input_split_sizes = gather_size_by_comm(S_LOCAL, group) + x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) + output_split_sizes = [S_LOCAL] * world_size + x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D) + x = x.permute(2, 1, 0, 3, 4).contiguous() + x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D) + x = _maybe_unpad_o_head(x, H_PAD, group) + return x + + return wait + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod def forward( @@ -1101,7 +1491,10 @@ def forward( out = out.to(torch.float32) lse = lse.to(torch.float32) - lse = lse.unsqueeze(-1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) @@ -1162,7 +1555,7 @@ def backward( grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAttention(torch.autograd.Function): @@ -1257,7 +1650,145 @@ def backward( x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) ) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None + + +class TemplatedUlyssesAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: 
torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + **kwargs, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + metadata = ulysses_anything_metadata(query) + query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) + key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) + value_wait = all_to_all_single_any_qkv_async(value, group, **metadata) + + query = query_wait() # type: torch.Tensor + key = key_wait() # type: torch.Tensor + value = value_wait() # type: torch.Tensor + + out = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=False, # ulysses anything only support forward pass now. + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) + out_wait = all_to_all_single_any_o_async(out, group, **metadata) + + if return_lse: + # lse: (B, S_Q_GLOBAL, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) + lse_wait = all_to_all_single_any_o_async(lse, group, **metadata) + out = out_wait() # type: torch.Tensor + lse = lse_wait() # type: torch.Tensor + lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) + else: + out = out_wait() # type: torch.Tensor + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") + + +def _templated_unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + scatter_idx: int = 2, + gather_idx: int = 1, +): + """ + Unified Sequence Parallelism attention combining Ulysses and ring attention. 
See: https://arxiv.org/abs/2405.07719 + """ + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + # context_layer is of shape (B, S, H_LOCAL, D) + output = SeqAllToAllDim.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) + if return_lse: + # lse is of shape (B, S, H_LOCAL, 1) + # Refer to: + # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if is_torch_version("<", "2.9.0"): + lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) + lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) + lse = lse.squeeze(-1) + return (output, lse) + return output def _templated_context_parallel_attention( @@ -1275,16 +1806,17 @@ def _templated_context_parallel_attention( backward_op, _parallel_config: Optional["ParallelConfig"] = None, ): - if attn_mask is not None: - raise ValueError("Attention mask is not yet supported for templated attention.") if is_causal: raise ValueError("Causal attention is not yet supported for templated attention.") if enable_gqa: raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 - if _parallel_config.context_parallel_config.ring_degree > 1: - return TemplatedRingAttention.apply( + if ( + _parallel_config.context_parallel_config.ring_degree > 1 + and _parallel_config.context_parallel_config.ulysses_degree > 1 + ): + return _templated_unified_attention( query, key, value, @@ -1298,8 +1830,8 @@ def _templated_context_parallel_attention( backward_op, _parallel_config, ) - elif _parallel_config.context_parallel_config.ulysses_degree > 1: - return TemplatedUlyssesAttention.apply( + elif _parallel_config.context_parallel_config.ring_degree > 1: + return TemplatedRingAttention.apply( query, key, value, @@ -1313,6 +1845,38 @@ def _templated_context_parallel_attention( backward_op, _parallel_config, ) + elif _parallel_config.context_parallel_config.ulysses_degree > 1: + if _parallel_config.context_parallel_config.ulysses_anything: + # For Any sequence lengths and Any head num support + return TemplatedUlyssesAnythingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + return TemplatedUlyssesAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) else: raise ValueError("Reaching this branch of code is unexpected. 
Please report a bug.") @@ -1329,6 +1893,7 @@ def _flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1336,6 +1901,9 @@ def _flash_attention( _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: lse = None + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 2.") + if _parallel_config is None: out = flash_attn_func( q=query, @@ -1378,6 +1946,7 @@ def _flash_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1385,6 +1954,9 @@ def _flash_attention_hub( _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: lse = None + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 2.") + func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn out = func( q=query, @@ -1521,11 +2093,15 @@ def _flash_attention_3( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 3.") + out, lse = _wrapped_flash_attn_3( q=query, k=key, @@ -1545,6 +2121,7 @@ def _flash_attention_3_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, is_causal: bool = False, window_size: Tuple[int, int] = (-1, -1), @@ -1555,6 +2132,8 @@ def _flash_attention_3_hub( ) -> torch.Tensor: if _parallel_config: raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.") + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 3.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn out = func( @@ -1694,12 +2273,16 @@ def _aiter_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for aiter attention") + if not return_lse and torch.is_grad_enabled(): # aiter requires return_lse=True by assertion when gradients are enabled. out, lse, *_ = aiter_flash_attn_func( @@ -1790,6 +2373,43 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): return out +def _prepare_additive_attn_mask( + attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True +) -> torch.Tensor: + """ + Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. + + This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. 
+ + Args: + attn_mask: 2D tensor [batch_size, seq_len_k] + - Boolean: True means attend, False means mask out + - Additive: 0.0 means attend, -inf means mask out + target_dtype: The dtype to convert the mask to (usually query.dtype) + reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting + + Returns: + Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if + reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. + """ + # Check if the mask is boolean or already additive + if attn_mask.dtype == torch.bool: + # Convert boolean to additive: True -> 0.0, False -> -inf + attn_mask = torch.where(attn_mask, 0.0, float("-inf")) + # Convert to target dtype + attn_mask = attn_mask.to(dtype=target_dtype) + else: + # Already additive mask - just ensure correct dtype + attn_mask = attn_mask.to(dtype=target_dtype) + + # Optionally reshape to 4D for broadcasting in attention mechanisms + if reshape_4d: + batch_size, seq_len_k = attn_mask.shape + attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) + + return attn_mask + + @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], @@ -1809,6 +2429,19 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") + + # Reshape 2D mask to 4D for SDPA + # SDPA accepts both boolean masks (torch.bool) and additive masks (float) + if ( + attn_mask is not None + and attn_mask.ndim == 2 + and attn_mask.shape[0] == query.shape[0] + and attn_mask.shape[1] == key.shape[1] + ): + # Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k] + # SDPA handles both boolean and additive masks correctly + attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) + if _parallel_config is None: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) out = torch.nn.functional.scaled_dot_product_attention( @@ -1931,11 +2564,13 @@ def _native_efficient_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_FLASH, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1943,22 +2578,43 @@ def _native_flash_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - if return_lse: - raise ValueError("Native flash attention backend does not support setting `return_lse=True`.") - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=None, # not supported - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for aiter attention") + + lse = None + if _parallel_config is None and not return_lse: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + 
attn_mask=None, # not supported + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=_native_flash_attention_forward_op, + backward_op=_native_flash_attention_backward_op, + _parallel_config=_parallel_config, ) - out = out.permute(0, 2, 1, 3) - return out + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out @_AttentionBackendRegistry.register( @@ -1998,34 +2654,52 @@ def _native_math_attention( @_AttentionBackendRegistry.register( AttentionBackendName._NATIVE_NPU, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, ) def _native_npu_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for NPU attention") if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") - query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) - out = npu_fusion_attention( - query, - key, - value, - query.size(1), # num_heads - input_layout="BNSD", - pse=None, - scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, - pre_tockens=65536, - next_tockens=65536, - keep_prob=1.0 - dropout_p, - sync=False, - inner_precise=0, - )[0] - out = out.transpose(1, 2).contiguous() + if _parallel_config is None: + out = npu_fusion_attention( + query, + key, + value, + query.size(2), # num_heads + input_layout="BSND", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + None, + scale, + None, + return_lse, + forward_op=_npu_attention_forward_op, + backward_op=_npu_attention_backward_op, + _parallel_config=_parallel_config, + ) return out @@ -2038,10 +2712,13 @@ def _native_xla_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for XLA attention") if return_lse: raise ValueError("XLA attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) @@ -2065,11 +2742,14 @@ def _sage_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") lse = None if _parallel_config is None: out = sageattn( @@ -2113,11 +2793,14 @@ def _sage_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: 
Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn if _parallel_config is None: @@ -2199,11 +2882,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp8_cuda( q=query, k=key, @@ -2223,11 +2909,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, k=key, @@ -2247,11 +2936,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp16_cuda( q=query, k=key, @@ -2271,11 +2963,14 @@ def _sage_qk_int8_pv_fp16_triton_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp16_triton( q=query, k=key, @@ -2313,10 +3008,34 @@ def _xformers_attention( attn_mask = xops.LowerTriangularMask() elif attn_mask is not None: if attn_mask.ndim == 2: - attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + # Convert 2D mask to 4D for xformers + # Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask) + # xformers requires 4D additive masks [batch, heads, seq_q, seq_k] + # Need memory alignment - create larger tensor and slice for alignment + original_seq_len = attn_mask.size(1) + aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8 + + # Create aligned 4D tensor and slice to ensure proper memory layout + aligned_mask = torch.zeros( + (batch_size, num_heads_q, seq_len_q, aligned_seq_len), + dtype=query.dtype, + device=query.device, + ) + # Convert to 4D additive mask (handles both boolean and additive inputs) + mask_additive = _prepare_additive_attn_mask( + attn_mask, target_dtype=query.dtype + ) # [batch, 1, 1, seq_len_k] + # Broadcast to [batch, heads, seq_q, seq_len_k] + aligned_mask[:, :, :, :original_seq_len] = mask_additive + # Mask out the padding (already -inf from zeros -> where with default) + aligned_mask[:, :, :, original_seq_len:] = float("-inf") + + # Slice to actual size with proper alignment + attn_mask = 
aligned_mask[:, :, :, :seq_len_kv] elif attn_mask.ndim != 4: raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") - attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + elif attn_mask.ndim == 4: + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) if enable_gqa: if num_heads_q % num_heads_kv != 0: diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py index c96b4fa88c49..0a5b7fff1c66 100644 --- a/src/diffusers/models/auto_model.py +++ b/src/diffusers/models/auto_model.py @@ -18,7 +18,7 @@ from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin -from ..utils import logging +from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code @@ -220,4 +220,11 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") kwargs = {**load_config_kwargs, **kwargs} - return model_cls.from_pretrained(pretrained_model_or_path, **kwargs) + model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs) + + load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs} + parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS] + load_id = "|".join("null" if p is None else p for p in parts) + model._diffusers_load_id = load_id + + return model diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 56df27f93cd7..8e7a9c81d2ad 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo +from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video +from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index ec301ef8ad51..19b666fdc4a8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -102,14 +102,14 @@ def get_block( attention_head_dim: int, norm_type: str, act_fn: str, - qkv_mutliscales: Tuple[int, ...] = (), + qkv_multiscales: Tuple[int, ...] 
= (), ): if block_type == "ResBlock": block = ResBlock(in_channels, out_channels, norm_type, act_fn) elif block_type == "EfficientViTBlock": block = EfficientViTBlock( - in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales + in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales ) else: @@ -247,7 +247,7 @@ def __init__( attention_head_dim=attention_head_dim, norm_type="rms_norm", act_fn="silu", - qkv_mutliscales=qkv_multiscales[i], + qkv_multiscales=qkv_multiscales[i], ) down_block_list.append(block) @@ -339,7 +339,7 @@ def __init__( attention_head_dim=attention_head_dim, norm_type=norm_type[i], act_fn=act_fn[i], - qkv_mutliscales=qkv_multiscales[i], + qkv_multiscales=qkv_multiscales[i], ) up_block_list.append(block) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 265f2abcfba6..95991dca3304 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -74,6 +74,7 @@ class AutoencoderKL( _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py index 616d0d415840..c02f11bef40a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py @@ -27,7 +27,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -410,7 +410,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return h -class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model for 2D images with spatial tiling support. @@ -486,27 +486,6 @@ def enable_tiling( self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. 
- """ - self.use_slicing = False - def _encode(self, x: torch.Tensor): batch_size, num_channels, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py index 2249063a9f00..973574e616bf 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py @@ -26,7 +26,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -584,7 +584,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin): +class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for HunyuanImage-2.1 Refiner. @@ -685,27 +685,6 @@ def enable_tiling( self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: _, _, _, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py index 4b1beb74a3bc..c662c1657513 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -26,7 +26,7 @@ from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -625,7 +625,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin): +class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for HunyuanVideo-1.5. 
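As in the HunyuanImage and HunyuanImage-Refiner diffs above, the hunk below deletes the per-class `disable_tiling` / `enable_slicing` / `disable_slicing` methods, and the VAE instead inherits them via the `AutoencoderMixin` now imported from `.vae`. A minimal sketch of the kind of mixin these deletions presumably rely on; the method bodies mirror the deleted code, but the exact diffusers implementation may differ:

```python
# Hypothetical sketch of the shared mixin (assumed API; not the verbatim diffusers code).
class AutoencoderMixin:
    use_slicing: bool = False
    use_tiling: bool = False

    def enable_slicing(self) -> None:
        # Encode/decode one sample of the batch at a time to reduce peak memory.
        self.use_slicing = True

    def disable_slicing(self) -> None:
        self.use_slicing = False

    def disable_tiling(self) -> None:
        # `enable_tiling` stays on each concrete VAE because it also configures tile sizes and strides.
        self.use_tiling = False
```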
@@ -723,27 +723,6 @@ def enable_tiling( self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: _, _, _, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py new file mode 100644 index 000000000000..01dd55a938b6 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -0,0 +1,1521 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +class PerChannelRMSNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + + For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values + across that dimension: + + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.channel_dim = channel_dim + self.eps = eps + + def forward(self, x: torch.Tensor, channel_dim: Optional[int] = None) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension. + """ + channel_dim = channel_dim or self.channel_dim + # Compute mean of squared values along `dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True) + # Normalize by the root-mean-square (RMS). 
+ rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. 
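+        inject_noise (`bool`, defaults to `False`):
+            Whether to add learned, per-channel-scaled spatial noise after each convolution.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to modulate the block with scale/shift values derived from a timestep embedding.
+        spatial_padding_mode (`str`, defaults to `"zeros"`):
+            Padding mode used for the spatial padding of the causal convolutions.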
+ """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = PerChannelRMSNorm() + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm2 = PerChannelRMSNorm() + self.dropout = nn.Dropout(dropout) + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] 
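+        # The optional noise injection above draws a single (height, width) noise map, scales it with the learned
+        # per-channel factors, and adds it to the activation (the same map is shared across the batch and all frames).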
+ + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = 
hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, causal=causal) + + return hidden_states + + +# Adapted from 
diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + residual=upsample_residual, + upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, causal=causal) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTX2VideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. 
+ layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. + downsample_type (`Tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: Tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal + + output_channel = out_channels + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + spatial_padding_mode=spatial_padding_mode, + ) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + causal = causal or self.is_causal + + hidden_states = 
hidden_states.reshape(
+            batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
+        )
+        # Thanks for driving me insane with the weird patching order :(
+        hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
+        hidden_states = self.conv_in(hidden_states, causal=causal)
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for down_block in self.down_blocks:
+                hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal)
+
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal)
+        else:
+            for down_block in self.down_blocks:
+                hidden_states = down_block(hidden_states, causal=causal)
+
+            hidden_states = self.mid_block(hidden_states, causal=causal)
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states, causal=causal)
+
+        last_channel = hidden_states[:, -1:]
+        last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
+        hidden_states = torch.cat([hidden_states, last_channel], dim=1)
+
+        return hidden_states
+
+
+# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2
+class LTX2VideoDecoder3d(nn.Module):
+    r"""
+    The `LTX2VideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
+    sample.
+
+    Args:
+        in_channels (`int`, defaults to 128):
+            Number of latent channels.
+        out_channels (`int`, defaults to 3):
+            Number of output channels.
+        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024)`):
+            The number of output channels for each block.
+        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True)`):
+            Whether a block should contain spatio-temporal upscaling layers or not.
+        layers_per_block (`Tuple[int, ...]`, defaults to `(5, 5, 5, 5)`):
+            The number of layers per block.
+        patch_size (`int`, defaults to `4`):
+            The size of spatial patches.
+        patch_size_t (`int`, defaults to `1`):
+            The size of temporal patches.
+        resnet_norm_eps (`float`, defaults to `1e-6`):
+            Epsilon value for ResNet normalization layers.
+        is_causal (`bool`, defaults to `False`):
+            Whether this layer behaves causally (future frames depend only on past frames) or not.
+        timestep_conditioning (`bool`, defaults to `False`):
+            Whether to condition the model on timesteps.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 128,
+        out_channels: int = 3,
+        block_out_channels: Tuple[int, ...] = (256, 512, 1024),
+        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True),
+        layers_per_block: Tuple[int, ...] = (5, 5, 5, 5),
+        patch_size: int = 4,
+        patch_size_t: int = 1,
+        resnet_norm_eps: float = 1e-6,
+        is_causal: bool = False,
+        inject_noise: Tuple[bool, ...] = (False, False, False),
+        timestep_conditioning: bool = False,
+        upsample_residual: Tuple[bool, ...] = (True, True, True),
+        upsample_factor: Tuple[int, ...]
= (2, 2, 2), + spatial_padding_mode: str = "reflect", + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + ) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) + else: + hidden_states = self.mid_block(hidden_states, temb, causal=causal) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, causal=causal) + + hidden_states = self.norm_out(hidden_states) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), 
+                hidden_dtype=hidden_states.dtype,
+            )
+            temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
+            temb = temb + self.scale_shift_table[None, ..., None, None, None]
+            shift, scale = temb.unbind(dim=1)
+            hidden_states = hidden_states * (1 + scale) + shift
+
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states, causal=causal)
+
+        p = self.patch_size
+        p_t = self.patch_size_t
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width)
+        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        return hidden_states
+
+
+class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
+    r"""
+    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
+    [LTX-2](https://huggingface.co/Lightricks/LTX-2).
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Args:
+        in_channels (`int`, defaults to `3`):
+            Number of input channels.
+        out_channels (`int`, defaults to `3`):
+            Number of output channels.
+        latent_channels (`int`, defaults to `128`):
+            Number of latent channels.
+        block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`):
+            The number of output channels for each block.
+        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, True)`):
+            Whether a block should contain spatio-temporal downscaling or not.
+        layers_per_block (`Tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`):
+            The number of layers per block.
+        patch_size (`int`, defaults to `4`):
+            The size of spatial patches.
+        patch_size_t (`int`, defaults to `1`):
+            The size of temporal patches.
+        resnet_norm_eps (`float`, defaults to `1e-6`):
+            Epsilon value for ResNet normalization layers.
+        scaling_factor (`float`, *optional*, defaults to `1.0`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
+        encoder_causal (`bool`, defaults to `True`):
+            Whether the encoder should behave causally (future frames depend only on past frames) or not.
+        decoder_causal (`bool`, defaults to `True`):
+            Whether the decoder should behave causally (future frames depend only on past frames) or not.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        latent_channels: int = 128,
+        block_out_channels: Tuple[int, ...] = (256, 512, 1024, 2048),
+        down_block_types: Tuple[str, ...] = (
+            "LTX2VideoDownBlock3D",
+            "LTX2VideoDownBlock3D",
+            "LTX2VideoDownBlock3D",
+            "LTX2VideoDownBlock3D",
+        ),
+        decoder_block_out_channels: Tuple[int, ...] = (256, 512, 1024),
+        layers_per_block: Tuple[int, ...] = (4, 6, 6, 2, 2),
+        decoder_layers_per_block: Tuple[int, ...]
= (5, 5, 5, 5), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False), + downsample_type: Tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: Tuple[bool, ...] = (True, True, True), + upsample_factor: Tuple[int, ...] = (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTX2VideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + self.decoder = LTX2VideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. 
+ # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_min_num_frames: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + tile_sample_stride_num_frames: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x, causal=causal) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x, causal=causal) + + enc = self.encoder(x, causal=causal) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
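# Editor's note: a minimal usage sketch for the tiling and slicing switches configured above.
# The checkpoint id, tile sizes, and top-level `diffusers` export are illustrative assumptions;
# the only APIs taken from this file are `enable_tiling()`, `decode()`, and the `use_slicing`
# flag set in `__init__`.
import torch

from diffusers import AutoencoderKLLTX2Video

vae = AutoencoderKLLTX2Video.from_pretrained(
    "Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.bfloat16
)
vae.enable_tiling(tile_sample_min_height=512, tile_sample_min_width=512)
vae.use_slicing = True  # decode one latent video at a time to save memory

latents = torch.randn(2, 128, 8, 64, 64, dtype=torch.bfloat16)
video = vae.decode(latents).sample  # (batch, channels, frames, height, width)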
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x, causal=causal) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + dec = self.decoder(z, temb, causal=causal) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice, causal=causal).sample + for z_slice, t_slice in zip(z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb, causal=causal).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder.
+ + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
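# Editor's note: a self-contained illustration of the linear crossfade used by the `blend_*`
# helpers above, reduced to one dimension for readability (same formula as `blend_h`). The
# overlapping edge of the left tile fades out while the right tile fades in, hiding the seam.
import torch

a = torch.ones(8)   # right edge of the tile decoded on the left
b = torch.zeros(8)  # left edge of the tile decoded on the right
blend_extent = 4
for x in range(blend_extent):
    b[x] = a[-blend_extent + x] * (1 - x / blend_extent) + b[x] * (x / blend_extent)
print(b)  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000])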
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile, causal=causal) + else: + tile = self.encoder(tile, causal=causal) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample + else: + decoded = self.decoder(tile, temb, causal=causal) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile 
in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, + sample_posterior: bool = False, + encoder_causal: Optional[bool] = None, + decoder_causal: Optional[bool] = None, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x, causal=encoder_causal).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb, causal=decoder_causal) + if not return_dict: + return (dec.sample,) + return dec diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py new file mode 100644 index 000000000000..b29629a29f80 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py @@ -0,0 +1,804 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. 
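# Editor's note: a short sketch of the asymmetric padding computed by `LTX2AudioCausalConv2d`,
# assuming kernel_size=3, dilation=1, and causality_axis="height" (the time axis of the mel
# spectrogram, given the encoder's expected (batch, channels, time, mel_bins) layout). All
# padding along the causal axis is placed before the first frame, so an output frame never
# depends on future frames, and a 3x3 convolution then restores the original length.
import torch
import torch.nn.functional as F

kernel_size, dilation = (3, 3), (1, 1)
pad_h = (kernel_size[0] - 1) * dilation[0]  # 2
pad_w = (kernel_size[1] - 1) * dilation[1]  # 2
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)  # (1, 1, 2, 0): leading pad on the time axis only

x = torch.randn(1, 2, 10, 64)  # (batch, channels, time, mel_bins)
print(F.pad(x, padding).shape)  # torch.Size([1, 2, 12, 66]); a k=3 conv brings this back to (10, 64)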
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group": + raise 
ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_type == "group": + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.non_linearity = nn.SiLU() + if causality_axis is not None: + self.conv1 = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + if norm_type == "group": + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.dropout = nn.Dropout(dropout) + if causality_axis is not None: + self.conv2 = LTX2AudioCausalConv2d( + out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + if causality_axis is not None: + self.conv_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + if causality_axis is not None: + self.nin_shortcut = LTX2AudioCausalConv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: + h = self.norm1(x) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) + + return x + h + + +class LTX2AudioDownsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_conv: + # Padding tuple is in the order: (left, right, top, bottom). + if self.causality_axis == "none": + pad = (0, 1, 0, 1) + elif self.causality_axis == "width": + pad = (2, 0, 0, 1) + elif self.causality_axis == "height": + pad = (0, 1, 2, 0) + elif self.causality_axis == "width-compatibility": + pad = (1, 0, 0, 1) + else: + raise ValueError( + f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`," + f" and `width-compatibility`." 
+ ) + + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # with_conv=False implies that causality_axis is "none" + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class LTX2AudioUpsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + if causality_axis is not None: + self.conv = LTX2AudioCausalConv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. + """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor: + batch, time, _ = audio_latents.shape + return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> Tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioEncoder(nn.Module): + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] 
= (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + double_z: bool = True, + ): + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels + base_resolution = resolution + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution + + for level in range(self.num_resolutions): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != self.num_resolutions - 1: + stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis) + curr_res = curr_res // 2 + + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + final_block_channels = block_in + z_channels = 2 * latent_channels if double_z else latent_channels + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + self.non_linearity = nn.SiLU() + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, 
stride=1, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # hidden_states expected shape: (batch_size, channels, time, num_mel_bins) + hidden_states = self.conv_in(hidden_states) + + for level in range(self.num_resolutions): + stage = self.down[level] + for block_idx, block in enumerate(stage.block): + hidden_states = block(hidden_states, temb=None) + if stage.attn: + hidden_states = stage.attn[block_idx](hidden_states) + + if level != self.num_resolutions - 1 and hasattr(stage, "downsample"): + hidden_states = stage.downsample(hidden_states) + + hidden_states = self.mid.block_1(hidden_states, temb=None) + hidden_states = self.mid.attn_1(hidden_states) + hidden_states = self.mid.block_2(hidden_states, temb=None) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.non_linearity(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class LTX2AudioDecoder(nn.Module): + """ + Symmetric decoder that reconstructs audio spectrograms from latent features. + + The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal + convolutions. + """ + + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: Tuple[int, ...] = (1, 2, 4), + norm_type: str = "group", + causality_axis: Optional[str] = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + ) -> None: + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = 
LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + self.up = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis) + curr_res *= 2 + + self.up.insert(0, stage) + + final_block_channels = block_in + + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1) + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, _, frames, mel_bins = sample.shape + + target_frames = frames * LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + hidden_features = block(hidden_features, temb=None) + if stage.attn: + hidden_features = stage.attn[block_idx](hidden_features) + + if level != 0 and hasattr(stage, "upsample"): + hidden_features = stage.upsample(hidden_features) + + if self.give_pre_end: + return hidden_features + + hidden = self.norm_out(hidden_features) + hidden = self.non_linearity(hidden) + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = 
F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + +class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin): + r""" + LTX2 audio VAE for encoding and decoding audio latent representations. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: Tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: Optional[Tuple[int, ...]] = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: Optional[str] = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: Optional[int] = 64, + double_z: bool = True, + ) -> None: + super().__init__() + + supported_causality_axes = {"none", "width", "height", "width-compatibility"} + if causality_axis not in supported_causality_axes: + raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}") + + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions + + self.encoder = LTX2AudioEncoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + double_z=double_z, + ) + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + ) + + # Per-channel statistics for normalizing and denormalizing the latent representation. 
This statics is computed over + # the entire dataset and stored in model's checkpoint under AudioVAE state_dict + latents_std = torch.ones((base_channels,)) + latents_mean = torch.zeros((base_channels,)) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + # TODO: calculate programmatically instead of hardcoding + self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4 + # TODO: confirm whether the mel compression ratio below is correct + self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True): + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + posterior = self.encode(sample).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec.sample,) + return dec diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 618801dfb605..7f7266146e6b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -394,6 +394,7 @@ def __init__( attn_scales=[], temperal_downsample=[True, True, False], dropout=0.0, + input_channels=3, non_linearity: str = "silu", ): super().__init__() @@ -410,7 +411,7 @@ def __init__( scale = 1.0 # init block - self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1) # downsample blocks self.down_blocks = nn.ModuleList([]) @@ -570,6 +571,7 @@ def __init__( attn_scales=[], temperal_upsample=[False, True, True], dropout=0.0, + input_channels=3, non_linearity: str = "silu", ): super().__init__() @@ -621,7 +623,7 @@ def __init__( # output blocks self.norm_out = QwenImageRMS_norm(out_dim, images=False) - self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1) self.gradient_checkpointing = False @@ -684,6 +686,7 @@ def __init__( attn_scales: List[float] = [], temperal_downsample: List[bool] = [False, True, True], dropout: float = 0.0, + input_channels: int = 3, latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], 
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], ) -> None: @@ -695,13 +698,13 @@ def __init__( self.temperal_upsample = temperal_downsample[::-1] self.encoder = QwenImageEncoder3d( - base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels ) self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) self.decoder = QwenImageDecoder3d( - base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels ) self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index b0b2960aaf18..761dff2dc61a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -619,6 +619,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): feat_idx[0] += 1 else: x = self.conv_out(x) + return x @@ -961,6 +962,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo """ _supports_gradient_checkpointing = False + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] # keys toignore when AlignDeviceHook moves inputs/outputs between devices # these are shared mutable state modified in-place _skip_keys = ["feat_cache", "feat_idx"] @@ -1259,14 +1261,20 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: `torch.Tensor`: The latent representation of the encoded videos. 
""" + _, _, num_frames, height, width = x.shape - latent_height = height // self.spatial_compression_ratio - latent_width = width // self.spatial_compression_ratio + encode_spatial_compression_ratio = self.spatial_compression_ratio + if self.config.patch_size is not None: + assert encode_spatial_compression_ratio % self.config.patch_size == 0 + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size - tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio - tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio - tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio - tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + latent_height = height // encode_spatial_compression_ratio + latent_width = width // encode_spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio blend_height = tile_latent_min_height - tile_latent_stride_height blend_width = tile_latent_min_width - tile_latent_stride_width @@ -1408,6 +1416,7 @@ def forward( """ x = sample posterior = self.encode(x).latent_dist + if sample_posterior: z = posterior.sample(generator=generator) else: diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 605c0d588c8c..04c90668a1db 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -41,9 +41,11 @@ def enable_cache(self, config) -> None: Enable caching techniques on the model. Args: - config (`Union[PyramidAttentionBroadcastConfig]`): + config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`): The configuration for applying the caching technique. 
Currently supported caching techniques are: - [`~hooks.PyramidAttentionBroadcastConfig`] + - [`~hooks.FasterCacheConfig`] + - [`~hooks.FirstBlockCacheConfig`] Example: @@ -66,10 +68,14 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, FirstBlockCacheConfig, + MagCacheConfig, PyramidAttentionBroadcastConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, + apply_mag_cache, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) if self.is_cache_enabled: @@ -81,18 +87,31 @@ def enable_cache(self, config) -> None: apply_faster_cache(self, config) elif isinstance(config, FirstBlockCacheConfig): apply_first_block_cache(self, config) + elif isinstance(config, MagCacheConfig): + apply_mag_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, TaylorSeerCacheConfig): + apply_taylorseer_cache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import ( + FasterCacheConfig, + FirstBlockCacheConfig, + HookRegistry, + MagCacheConfig, + PyramidAttentionBroadcastConfig, + TaylorSeerCacheConfig, + ) from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK + from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK + from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -105,8 +124,13 @@ def disable_cache(self) -> None: elif isinstance(self._cache_config, FirstBlockCacheConfig): registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, MagCacheConfig): + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TaylorSeerCacheConfig): + registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py deleted file mode 100644 index c18bd8751dcb..000000000000 --- a/src/diffusers/models/controlnet.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
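# Editor's note (re: the cache_utils.py changes above): a hedged sketch of enabling one of the
# newly supported cache hooks on a transformer-based pipeline. The checkpoint id is a
# placeholder and the MagCacheConfig constructor arguments are not shown because they are not
# part of this diff; consult the hook's docstring for the required thresholds.
import torch

from diffusers import DiffusionPipeline
from diffusers.hooks import MagCacheConfig

pipe = DiffusionPipeline.from_pretrained("<checkpoint-id>", torch_dtype=torch.bfloat16)

pipe.transformer.enable_cache(MagCacheConfig())  # assumed defaults; see the hook docs
# ... run inference ...
pipe.transformer.disable_cache()  # removes the MagCache hooks registered above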
-from typing import Optional, Tuple, Union - -from ..utils import deprecate -from .controlnets.controlnet import ( # noqa - ControlNetConditioningEmbedding, - ControlNetModel, - ControlNetOutput, - zero_module, -) - - -class ControlNetOutput(ControlNetOutput): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead." - deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message) - super().__init__(*args, **kwargs) - - -class ControlNetModel(ControlNetModel): - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 3, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - projection_class_embeddings_input_dim: Optional[int] = None, - controlnet_conditioning_channel_order: str = "rgb", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - addition_embed_type_num_heads: int = 64, - ): - deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead." 
- deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message) - super().__init__( - in_channels=in_channels, - conditioning_channels=conditioning_channels, - flip_sin_to_cos=flip_sin_to_cos, - freq_shift=freq_shift, - down_block_types=down_block_types, - mid_block_type=mid_block_type, - only_cross_attention=only_cross_attention, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - downsample_padding=downsample_padding, - mid_block_scale_factor=mid_block_scale_factor, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - encoder_hid_dim=encoder_hid_dim, - encoder_hid_dim_type=encoder_hid_dim_type, - attention_head_dim=attention_head_dim, - num_attention_heads=num_attention_heads, - use_linear_projection=use_linear_projection, - class_embed_type=class_embed_type, - addition_embed_type=addition_embed_type, - addition_time_embed_dim=addition_time_embed_dim, - num_class_embeds=num_class_embeds, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - global_pool_conditions=global_pool_conditions, - addition_embed_type_num_heads=addition_embed_type_num_heads, - ) - - -class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead." - deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message) - super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py deleted file mode 100644 index e82748436d86..000000000000 --- a/src/diffusers/models/controlnet_flux.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List - -from ..utils import deprecate, logging -from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class FluxControlNetOutput(FluxControlNetOutput): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead." 
- deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message) - super().__init__(*args, **kwargs) - - -class FluxControlNetModel(FluxControlNetModel): - def __init__( - self, - patch_size: int = 1, - in_channels: int = 64, - num_layers: int = 19, - num_single_layers: int = 38, - attention_head_dim: int = 128, - num_attention_heads: int = 24, - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, - axes_dims_rope: List[int] = [16, 56, 56], - num_mode: int = None, - conditioning_embedding_channels: int = None, - ): - deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead." - deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message) - super().__init__( - patch_size=patch_size, - in_channels=in_channels, - num_layers=num_layers, - num_single_layers=num_single_layers, - attention_head_dim=attention_head_dim, - num_attention_heads=num_attention_heads, - joint_attention_dim=joint_attention_dim, - pooled_projection_dim=pooled_projection_dim, - guidance_embeds=guidance_embeds, - axes_dims_rope=axes_dims_rope, - num_mode=num_mode, - conditioning_embedding_channels=conditioning_embedding_channels, - ) - - -class FluxMultiControlNetModel(FluxMultiControlNetModel): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead." - deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message) - super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py deleted file mode 100644 index d239ad4eb3e8..000000000000 --- a/src/diffusers/models/controlnet_sd3.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from ..utils import deprecate, logging -from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class SD3ControlNetOutput(SD3ControlNetOutput): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead." 
- deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message) - super().__init__(*args, **kwargs) - - -class SD3ControlNetModel(SD3ControlNetModel): - def __init__( - self, - sample_size: int = 128, - patch_size: int = 2, - in_channels: int = 16, - num_layers: int = 18, - attention_head_dim: int = 64, - num_attention_heads: int = 18, - joint_attention_dim: int = 4096, - caption_projection_dim: int = 1152, - pooled_projection_dim: int = 2048, - out_channels: int = 16, - pos_embed_max_size: int = 96, - extra_conditioning_channels: int = 0, - ): - deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead." - deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message) - super().__init__( - sample_size=sample_size, - patch_size=patch_size, - in_channels=in_channels, - num_layers=num_layers, - attention_head_dim=attention_head_dim, - num_attention_heads=num_attention_heads, - joint_attention_dim=joint_attention_dim, - caption_projection_dim=caption_projection_dim, - pooled_projection_dim=pooled_projection_dim, - out_channels=out_channels, - pos_embed_max_size=pos_embed_max_size, - extra_conditioning_channels=extra_conditioning_channels, - ) - - -class SD3MultiControlNetModel(SD3MultiControlNetModel): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead." - deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message) - super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py deleted file mode 100644 index 5c67af4fe9c1..000000000000 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Optional, Tuple, Union - -from ..utils import deprecate, logging -from .controlnets.controlnet_sparsectrl import ( # noqa - SparseControlNetConditioningEmbedding, - SparseControlNetModel, - SparseControlNetOutput, - zero_module, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class SparseControlNetOutput(SparseControlNetOutput): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead." 
- deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message) - super().__init__(*args, **kwargs) - - -class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding): - def __init__(self, *args, **kwargs): - deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead." - deprecate( - "diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message - ) - super().__init__(*args, **kwargs) - - -class SparseControlNetModel(SparseControlNetModel): - def __init__( - self, - in_channels: int = 4, - conditioning_channels: int = 4, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str, ...] = ( - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - "CrossAttnDownBlockMotion", - "DownBlockMotion", - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 768, - transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None, - temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, - attention_head_dim: Union[int, Tuple[int, ...]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, - use_linear_projection: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), - global_pool_conditions: bool = False, - controlnet_conditioning_channel_order: str = "rgb", - motion_max_seq_length: int = 32, - motion_num_attention_heads: int = 8, - concat_conditioning_mask: bool = True, - use_simplified_condition_embedding: bool = True, - ): - deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead." 
- deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message) - super().__init__( - in_channels=in_channels, - conditioning_channels=conditioning_channels, - flip_sin_to_cos=flip_sin_to_cos, - freq_shift=freq_shift, - down_block_types=down_block_types, - only_cross_attention=only_cross_attention, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - downsample_padding=downsample_padding, - mid_block_scale_factor=mid_block_scale_factor, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - norm_eps=norm_eps, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - transformer_layers_per_mid_block=transformer_layers_per_mid_block, - temporal_transformer_layers_per_block=temporal_transformer_layers_per_block, - attention_head_dim=attention_head_dim, - num_attention_heads=num_attention_heads, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - conditioning_embedding_out_channels=conditioning_embedding_out_channels, - global_pool_conditions=global_pool_conditions, - controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, - motion_max_seq_length=motion_max_seq_length, - motion_num_attention_heads=motion_num_attention_heads, - concat_conditioning_mask=concat_conditioning_mask, - use_simplified_condition_embedding=use_simplified_condition_embedding, - ) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index 7ce352879daa..fee7f231e899 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -19,6 +19,7 @@ ) from .controlnet_union import ControlNetUnionModel from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .controlnet_z_image import ZImageControlNetModel from .multicontrolnet import MultiControlNetModel from .multicontrolnet_union import MultiControlNetUnionModel diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py index 5c89c9267db4..0b5b9fa3efba 100644 --- a/src/diffusers/models/controlnets/controlnet.py +++ b/src/diffusers/models/controlnets/controlnet.py @@ -19,6 +19,7 @@ from torch.nn import functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import BaseOutput, logging from ..attention import AttentionMixin @@ -106,7 +107,7 @@ def forward(self, conditioning): return embedding -class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin): +class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): """ A ControlNet model. 
diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 86971271788f..fa374285eec1 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..attention import AttentionMixin from ..cache_utils import CacheMixin from ..controlnets.controlnet import zero_module @@ -31,6 +31,7 @@ QwenImageTransformerBlock, QwenTimestepProjEmbeddings, RMSNorm, + compute_text_seq_len_from_mask, ) @@ -136,7 +137,7 @@ def forward( return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ - The [`FluxTransformer2DModel`] forward method. + The [`QwenImageControlNetModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): @@ -147,24 +148,39 @@ def forward( The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. - block_controlnet_hidden_states: (`list` of `torch.Tensor`): - A list of tensors that if specified are added to the residuals of transformer blocks. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*): + **Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence + length. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain - tuple. + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. + If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where + the first element is the controlnet block samples. 
""" + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in " + "version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` " + "and `encoder_hidden_states_mask`.", + standard_warn=False, + ) + if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -186,32 +202,47 @@ def forward( temb = self.time_text_embed(timestep, hidden_states) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) timestep = timestep.to(hidden_states.dtype) encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Construct joint attention mask once to avoid reconstructing in every block + block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask + block_samples = () - for index_block, block in enumerate(self.transformer_blocks): + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, + block_attention_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, ) block_samples = block_samples + (hidden_states,) @@ -267,6 +298,15 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[QwenImageControlNetOutput, Tuple]: + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be " + "removed in version 0.39.0. 
The text sequence length is now automatically inferred from " + "`encoder_hidden_states` and `encoder_hidden_states_mask`.", + standard_warn=False, + ) # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1: @@ -281,7 +321,6 @@ def forward( encoder_hidden_states_mask=encoder_hidden_states_mask, timestep=timestep, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py new file mode 100644 index 000000000000..3f79ec925419 --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -0,0 +1,845 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Literal, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin +from ...models.attention_processor import Attention +from ...models.normalization import RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_dispatch import dispatch_attention_fn +from ..controlnets.controlnet import zero_module +from ..modeling_utils import ModelMixin + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +# Copied from 
diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +# Copied from diffusers.models.transformers.transformer_z_image.FeedForward +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +# Copied from diffusers.models.transformers.transformer_z_image.select_per_token +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + +@maybe_allow_in_graph +# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock +class ZImageTransformerBlock(nn.Module): + 
def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, + ): + if self.modulation: + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +# Copied from diffusers.models.transformers.transformer_z_image.RopeEmbedder +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + 
self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +@maybe_allow_in_graph +class ZImageControlTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + # Control variant start + self.block_id = block_id + if block_id == 0: + self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) + self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) + + def forward( + self, + c: torch.Tensor, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + # Control + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + # Compared to `ZImageTransformerBlock` x -> c + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + c = c + gate_msa * 
self.attention_norm2(attn_out) + + # FFN block + c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis) + c = c + self.attention_norm2(attn_out) + + # FFN block + c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c))) + + # Control + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + control_layers_places: List[int] = None, + control_refiner_layers_places: List[int] = None, + control_in_dim=None, + add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None, + all_patch_size=(2,), + all_f_patch_size=(1,), + dim=3840, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + ): + super().__init__() + self.control_layers_places = control_layers_places + self.control_in_dim = control_in_dim + self.control_refiner_layers_places = control_refiner_layers_places + self.add_control_noise_refiner = add_control_noise_refiner + + assert 0 in self.control_layers_places + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + if self.add_control_noise_refiner == "control_layers": + self.control_noise_refiner = None + elif self.add_control_noise_refiner == "control_noise_refiner": + self.control_noise_refiner = nn.ModuleList( + [ + ZImageControlTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + block_id=layer_id, + ) + for layer_id in range(n_refiner_layers) + ] + ) + else: + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + self.t_scale: Optional[float] = None + self.t_embedder: Optional[TimestepEmbedder] = None + self.all_x_embedder: Optional[nn.ModuleDict] = None + self.cap_embedder: Optional[nn.Sequential] = None + self.rope_embedder: Optional[RopeEmbedder] = None + self.noise_refiner: Optional[nn.ModuleList] = None + self.context_refiner: Optional[nn.ModuleList] = None + self.x_pad_token: Optional[nn.Parameter] = None + self.cap_pad_token: Optional[nn.Parameter] = None + + @classmethod + def from_transformer(cls, controlnet, transformer): + controlnet.t_scale = transformer.t_scale + controlnet.t_embedder = transformer.t_embedder + controlnet.all_x_embedder = transformer.all_x_embedder + controlnet.cap_embedder = transformer.cap_embedder + controlnet.rope_embedder = transformer.rope_embedder + controlnet.noise_refiner = transformer.noise_refiner + controlnet.context_refiner = transformer.context_refiner + controlnet.x_pad_token = transformer.x_pad_token + controlnet.cap_pad_token = 
transformer.cap_pad_token + return controlnet + + @staticmethod + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids + def _pad_with_ids( + self, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, + ): + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed + def patchify_and_embed( + self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int + ): + """Patchify for basic mode: single image per batch item.""" + device = all_image[0].device + all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] + + for image, cap_feat in zip(all_image, all_cap_feats): + # Caption + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device + ) + all_cap_out.append(cap_out) + all_cap_pos_ids.append(cap_pos_ids) + all_cap_pad_mask.append(cap_pad_mask) + + # Image + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device + ) + all_img_out.append(img_out) + all_img_size.append(size) + all_img_pos_ids.append(img_pos_ids) + 
all_img_pad_mask.append(img_pad_mask) + + return ( + all_img_out, + all_cap_out, + all_img_size, + all_img_pos_ids, + all_cap_pos_ids, + all_img_pad_mask, + all_cap_pad_mask, + ) + + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + all_image_out = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return all_image_out + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + control_context: List[torch.Tensor], + conditioning_scale: float = 1.0, + patch_size=2, + f_patch_size=1, + ): + if ( + self.t_scale is None + or self.t_embedder is None + or self.all_x_embedder is None + or self.cap_embedder is None + or self.rope_embedder is None + or self.noise_refiner is None + or self.context_refiner is None + or self.x_pad_token is None + or self.cap_pad_token is None + ): + raise ValueError( + "Required modules are `None`, use `from_transformer` to share required modules from `transformer`." + ) + + assert patch_size in self.config.all_patch_size + assert f_patch_size in self.config.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + control_context = self.patchify(control_context, patch_size, f_patch_size) + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + + # x embed & refine + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if self.add_control_noise_refiner is not None: + if self.add_control_noise_refiner == "control_layers": 
+ layers = self.control_layers + elif self.add_control_noise_refiner == "control_noise_refiner": + layers = self.control_noise_refiner + else: + raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.") + for layer in layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_context = self._gradient_checkpointing_func( + layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input) + + hints = torch.unbind(control_context)[:-1] + control_context = torch.unbind(control_context)[-1] + noise_refiner_block_samples = { + layer_idx: hints[idx] * conditioning_scale + for idx, layer_idx in enumerate(self.control_refiner_layers_places) + } + else: + noise_refiner_block_samples = None + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer_idx, layer in enumerate(self.noise_refiner): + x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + if noise_refiner_block_samples is not None: + if layer_idx in noise_refiner_block_samples: + x = x + noise_refiner_block_samples[layer_idx] + else: + for layer_idx, layer in enumerate(self.noise_refiner): + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + if noise_refiner_block_samples is not None: + if layer_idx in noise_refiner_block_samples: + x = x + noise_refiner_block_samples[layer_idx] + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + 
unified_attn_mask[i, :seq_len] = 1 + + ## ControlNet start + if not self.add_control_noise_refiner: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.control_noise_refiner: + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_context_unified = self._gradient_checkpointing_func( + layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + else: + control_context_unified = layer( + control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + + hints = torch.unbind(control_context_unified)[:-1] + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places) + } + return controlnet_block_samples diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b48ba6b4873..6d2e8df9c286 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -47,6 +47,7 @@ is_torch_version, logging, ) +from ..utils.distributed_utils import is_torch_dist_rank_zero logger = logging.get_logger(__name__) @@ -354,8 +355,9 @@ def _load_shard_file( state_dict_folder=None, ignore_mismatched_sizes=False, low_cpu_mem_usage=False, + disable_mmap=False, ): - state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap) mismatched_keys = _find_mismatched_keys( state_dict, model_state_dict, @@ -401,6 +403,7 @@ def _load_shard_files_with_threadpool( state_dict_folder=None, ignore_mismatched_sizes=False, low_cpu_mem_usage=False, + disable_mmap=False, ): # Do not spawn anymore workers than you need num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS) @@ -427,10 +430,15 @@ def _load_shard_files_with_threadpool( state_dict_folder=state_dict_folder, ignore_mismatched_sizes=ignore_mismatched_sizes, low_cpu_mem_usage=low_cpu_mem_usage, + disable_mmap=disable_mmap, ) + tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"} + if not is_torch_dist_rank_zero(): + tqdm_kwargs["disable"] = True + with ThreadPoolExecutor(max_workers=num_workers) as executor: - with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar: + with logging.tqdm(**tqdm_kwargs) as pbar: futures = [executor.submit(load_one, shard_file) for shard_file in shard_files] for future in as_completed(futures): result = future.result() diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f06822c741ca..63e50af61771 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -59,11 +59,8 @@ is_torch_version, logging, ) -from ..utils.hub_utils import ( - PushToHubMixin, - load_or_create_model_card, - populate_model_card, -) +from 
..utils.distributed_utils import is_torch_dist_rank_zero +from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card from ..utils.torch_utils import empty_device_cache from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig from .model_loading_utils import ( @@ -531,6 +528,8 @@ def enable_group_offload( record_stream: bool = False, low_cpu_mem_usage=False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[str] = None, + exclude_kwargs: Optional[str] = None, ) -> None: r""" Activates group offloading for the current model. @@ -570,6 +569,7 @@ def enable_group_offload( f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please " f"open an issue at https://github.com/huggingface/diffusers/issues." ) + apply_group_offloading( module=self, onload_device=onload_device, @@ -581,6 +581,8 @@ def enable_group_offload( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + exclude_kwargs=exclude_kwargs, ) def set_attention_backend(self, backend: str) -> None: @@ -597,6 +599,7 @@ def set_attention_backend(self, backend: str) -> None: from .attention import AttentionModuleMixin from .attention_dispatch import ( AttentionBackendName, + _AttentionBackendRegistry, _check_attention_backend_requirements, _maybe_download_kernel_for_backend, ) @@ -605,6 +608,16 @@ def set_attention_backend(self, backend: str) -> None: from .attention_processor import Attention, MochiAttention logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + + parallel_config_set = False + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if getattr(processor, "_parallel_config", None) is not None: + parallel_config_set = True + break backend = backend.lower() available_backends = {x.value for x in AttentionBackendName.__members__.values()} @@ -612,10 +625,17 @@ def set_attention_backend(self, backend: str) -> None: raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) backend = AttentionBackendName(backend) + if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend): + compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel) + raise ValueError( + f"Context parallelism is enabled but current attention backend '{backend.value}' " + f"does not support context parallelism. " + f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()`." + ) + _check_attention_backend_requirements(backend) _maybe_download_kernel_for_backend(backend) - attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): if not isinstance(module, attention_classes): continue @@ -624,6 +644,9 @@ def set_attention_backend(self, backend: str) -> None: continue processor._attention_backend = backend + # Important to set the active backend so that it propagates gracefully throughout. + _AttentionBackendRegistry.set_active_backend(backend) + def reset_attention_backend(self) -> None: """ Resets the attention backend for the model. 
Following calls to `forward` will use the environment default, if @@ -1304,6 +1327,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P keep_in_fp32_modules=keep_in_fp32_modules, dduf_entries=dduf_entries, is_parallel_loading_enabled=is_parallel_loading_enabled, + disable_mmap=disable_mmap, ) loading_info = { "missing_keys": missing_keys, @@ -1358,12 +1382,12 @@ def cuda(self, *args, **kwargs): # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - if getattr(self, "is_loaded_in_8bit", False): + if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"): raise ValueError( - "Calling `cuda()` is not supported for `8-bit` quantized models. " - " Please use the model as it is, since the model has already been set to the correct devices." + "Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0." ) - elif is_bitsandbytes_version("<", "0.43.2"): + elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"): raise ValueError( "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." @@ -1410,17 +1434,16 @@ def to(self, *args, **kwargs): ) if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: - if getattr(self, "is_loaded_in_8bit", False): + if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"): raise ValueError( - "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" - " model has already been set to the correct devices and casted to the correct `dtype`." + "Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0." ) - elif is_bitsandbytes_version("<", "0.43.2"): + elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"): raise ValueError( "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." ) - if _is_group_offload_enabled(self) and device_arg_or_kwarg_present: logger.warning( f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported." @@ -1536,7 +1559,7 @@ def enable_parallelism( f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' " f"is using backend '{attention_backend.value}' which does not support context parallelism. " f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before " - f"calling `enable_parallelism()`." + f"calling `model.enable_parallelism()`." ) # All modules use the same attention processor and backend. 
We don't need to @@ -1590,6 +1613,7 @@ def _load_pretrained_model( offload_folder: Optional[Union[str, os.PathLike]] = None, dduf_entries: Optional[Dict[str, DDUFEntry]] = None, is_parallel_loading_enabled: Optional[bool] = False, + disable_mmap: bool = False, ): model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -1658,6 +1682,7 @@ def _load_pretrained_model( state_dict_folder=state_dict_folder, ignore_mismatched_sizes=ignore_mismatched_sizes, low_cpu_mem_usage=low_cpu_mem_usage, + disable_mmap=disable_mmap, ) if is_parallel_loading_enabled: @@ -1667,7 +1692,10 @@ def _load_pretrained_model( else: shard_files = resolved_model_file if len(resolved_model_file) > 1: - shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") + shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"} + if not is_torch_dist_rank_zero(): + shard_tqdm_kwargs["disable"] = True + shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs) for shard_file in shard_files: offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index c0b4ad40055a..a98cb491146b 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -366,7 +366,12 @@ def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwarg hidden_states = self.conv2(hidden_states) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor.contiguous()) + # Only use contiguous() during training to avoid DDP gradient stride mismatch warning. + # In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU. + # Issue: https://github.com/huggingface/diffusers/issues/12975 + if self.training: + input_tensor = input_tensor.contiguous() + input_tensor = self.conv_shortcut(input_tensor) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716e1..d9d1b27a1e40 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,13 +27,16 @@ from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel + from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel + from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel + from .transformer_ltx2 import LTX2VideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py index 373b470ae37b..2b0c2667072b 100644 --- a/src/diffusers/models/transformers/transformer_cosmos.py +++ 
b/src/diffusers/models/transformers/transformer_cosmos.py @@ -439,6 +439,9 @@ def __init__( rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0), concat_padding_mask: bool = True, extra_pos_embed_type: Optional[str] = "learnable", + use_crossattn_projection: bool = False, + crossattn_proj_in_channels: int = 1024, + encoder_hidden_states_channels: int = 1024, ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim @@ -485,6 +488,12 @@ def __init__( hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False ) + if self.config.use_crossattn_projection: + self.crossattn_proj = nn.Sequential( + nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True), + nn.GELU(), + ) + self.gradient_checkpointing = False def forward( @@ -524,6 +533,7 @@ def forward( post_patch_num_frames = num_frames // p_t post_patch_height = height // p_h post_patch_width = width // p_w + hidden_states = self.patch_embed(hidden_states) hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C] @@ -546,6 +556,9 @@ def forward( else: assert False + if self.config.use_crossattn_projection: + encoder_hidden_states = self.crossattn_proj(encoder_hidden_states) + # 5. Transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 16c526f437f2..1a4464432425 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -717,11 +717,7 @@ def forward( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) - if is_torch_npu_available(): - freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = self.pos_embed(ids) + image_rotary_emb = self.pos_embed(ids) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index c10bf3ed4f7b..9cadfcefc497 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from 
..attention_dispatch import dispatch_attention_fn @@ -585,7 +585,13 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: class Flux2TimestepGuidanceEmbeddings(nn.Module): - def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False): + def __init__( + self, + in_channels: int = 256, + embedding_dim: int = 6144, + bias: bool = False, + guidance_embeds: bool = True, + ): super().__init__() self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0) @@ -593,20 +599,24 @@ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias ) - self.guidance_embedder = TimestepEmbedding( - in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias - ) + if guidance_embeds: + self.guidance_embedder = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias + ) + else: + self.guidance_embedder = None def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D) - guidance_proj = self.time_proj(guidance) - guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) - - time_guidance_emb = timesteps_emb + guidance_emb - - return time_guidance_emb + if guidance is not None and self.guidance_embedder is not None: + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D) + time_guidance_emb = timesteps_emb + guidance_emb + return time_guidance_emb + else: + return timesteps_emb class Flux2Modulation(nn.Module): @@ -698,6 +708,7 @@ def __init__( axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32), rope_theta: int = 2000, eps: float = 1e-6, + guidance_embeds: bool = True, ): super().__init__() self.out_channels = out_channels or in_channels @@ -708,7 +719,10 @@ def __init__( # 2. Combined timestep + guidance embedding self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( - in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False + in_channels=timestep_guidance_channels, + embedding_dim=self.inner_dim, + bias=False, + guidance_embeds=guidance_embeds, ) # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.) @@ -815,7 +829,9 @@ def forward( # 1. 
Calculate timestep embedding and modulation parameters timestep = timestep.to(hidden_states.dtype) * 1000 - guidance = guidance.to(hidden_states.dtype) * 1000 + + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 temb = self.time_guidance_embed(timestep, guidance) @@ -835,14 +851,8 @@ def forward( if txt_ids.ndim == 3: txt_ids = txt_ids[0] - if is_torch_npu_available(): - freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) - image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) - freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) - text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) - else: - image_rotary_emb = self.pos_embed(img_ids) - text_rotary_emb = self.pos_embed(txt_ids) + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) concat_rotary_emb = ( torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py new file mode 100644 index 000000000000..6f7ed2fca1c9 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -0,0 +1,672 @@ +# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
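To make the calling convention of the new GlmImageTransformer2DModel defined below in this file concrete, here is a minimal usage sketch. The deliberately tiny configuration, tensor shapes, and import path are illustrative assumptions (chosen so shapes line up with patch_size=2), not values taken from this patch.

import torch
from diffusers.models.transformers.transformer_glm_image import GlmImageTransformer2DModel

# Illustrative tiny config; the file's defaults are much larger (30 layers, 64 heads, 2560-dim hidden size).
model = GlmImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    text_embed_dim=32,
    time_embed_dim=16,
    condition_dim=4,
    prior_vq_quantizer_codebook_size=64,
).eval()

batch, height, width = 1, 8, 8                        # latent height/width, divisible by patch_size
num_patches = (height // 2) * (width // 2)            # 16 patch tokens

latents = torch.randn(batch, 16, height, width)       # (B, in_channels, H, W)
glyph_embeds = torch.randn(batch, 12, 32)             # (B, text_seq_len, text_embed_dim)
prior_token_id = torch.zeros(batch, num_patches, dtype=torch.long)    # one VQ prior token per patch
prior_token_drop = torch.zeros(batch, num_patches, dtype=torch.bool)  # True zeroes the prior embedding
timestep = torch.tensor([500], dtype=torch.long)
target_size = torch.tensor([[1024.0, 1024.0]])        # SDXL-style size condition, (B, 2)
crop_coords = torch.tensor([[0.0, 0.0]])              # SDXL-style crop condition, (B, 2)

with torch.no_grad():
    out = model(
        hidden_states=latents,
        encoder_hidden_states=glyph_embeds,
        prior_token_id=prior_token_id,
        prior_token_drop=prior_token_drop,
        timestep=timestep,
        target_size=target_size,
        crop_coords=crop_coords,
    )
print(out.sample.shape)  # torch.Size([1, 16, 8, 8]) -- unpatchified latent prediction

A GlmImageKVCache(num_layers=...) instance can additionally be passed as kv_caches, presumably to reuse keys/values computed from condition images; its per-layer caches are switched between "write", "read", and "skip" via set_mode, as described in the class docstrings below.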
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LayerNorm, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + # (B, 2 * condition_dim) + condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + conditioning = F.silu(conditioning) + + return conditioning + + +class GlmImageImageProjector(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + + hidden_states = hidden_states.reshape( + batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size + ) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + hidden_states = self.proj(hidden_states) + + return hidden_states + + +class GlmImageAdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, dim: int) -> None: + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = hidden_states.dtype + 
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype) + + emb = self.linear(temb) + ( + shift_msa, + c_shift_msa, + scale_msa, + c_scale_msa, + gate_msa, + c_gate_msa, + shift_mlp, + c_shift_mlp, + scale_mlp, + c_scale_mlp, + gate_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + + hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) + + +class GlmImageLayerKVCache: + """KV cache for GlmImage model. + Supports per-sample caching for batch processing where each sample may have different condition images. + """ + + def __init__(self): + self.k_caches: List[Optional[torch.Tensor]] = [] + self.v_caches: List[Optional[torch.Tensor]] = [] + self.mode: Optional[str] = None # "write", "read", "skip" + self.current_sample_idx: int = 0 # Current sample index for writing + + def store(self, k: torch.Tensor, v: torch.Tensor): + """Store KV cache for the current sample.""" + # k, v shape: (1, seq_len, num_heads, head_dim) + if len(self.k_caches) <= self.current_sample_idx: + # First time storing for this sample + self.k_caches.append(k) + self.v_caches.append(v) + else: + # Append to existing cache for this sample (multiple condition images) + self.k_caches[self.current_sample_idx] = torch.cat([self.k_caches[self.current_sample_idx], k], dim=1) + self.v_caches[self.current_sample_idx] = torch.cat([self.v_caches[self.current_sample_idx], v], dim=1) + + def get(self, k: torch.Tensor, v: torch.Tensor): + """Get combined KV cache for all samples in the batch. + + Args: + k: Current key tensor, shape (batch_size, seq_len, num_heads, head_dim) + v: Current value tensor, shape (batch_size, seq_len, num_heads, head_dim) + Returns: + Combined key and value tensors with cached values prepended. + """ + batch_size = k.shape[0] + num_cached_samples = len(self.k_caches) + if num_cached_samples == 0: + return k, v + if num_cached_samples == 1: + # Single cache, expand for all batch samples (shared condition images) + k_cache_expanded = self.k_caches[0].expand(batch_size, -1, -1, -1) + v_cache_expanded = self.v_caches[0].expand(batch_size, -1, -1, -1) + elif num_cached_samples == batch_size: + # Per-sample cache, concatenate along batch dimension + k_cache_expanded = torch.cat(self.k_caches, dim=0) + v_cache_expanded = torch.cat(self.v_caches, dim=0) + else: + # Mismatch: try to handle by repeating the caches + # This handles cases like num_images_per_prompt > 1 + repeat_factor = batch_size // num_cached_samples + if batch_size % num_cached_samples == 0: + k_cache_list = [] + v_cache_list = [] + for i in range(num_cached_samples): + k_cache_list.append(self.k_caches[i].expand(repeat_factor, -1, -1, -1)) + v_cache_list.append(self.v_caches[i].expand(repeat_factor, -1, -1, -1)) + k_cache_expanded = torch.cat(k_cache_list, dim=0) + v_cache_expanded = torch.cat(v_cache_list, dim=0) + else: + raise ValueError( + f"Cannot match {num_cached_samples} cached samples to batch size {batch_size}. " + f"Batch size must be a multiple of the number of cached samples." 
+ ) + + k_combined = torch.cat([k_cache_expanded, k], dim=1) + v_combined = torch.cat([v_cache_expanded, v], dim=1) + return k_combined, v_combined + + def clear(self): + self.k_caches = [] + self.v_caches = [] + self.mode = None + self.current_sample_idx = 0 + + def next_sample(self): + """Move to the next sample for writing.""" + self.current_sample_idx += 1 + + +class GlmImageKVCache: + """Container for all layers' KV caches. + Supports per-sample caching for batch processing where each sample may have different condition images. + """ + + def __init__(self, num_layers: int): + self.num_layers = num_layers + self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)] + + def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache: + return self.caches[layer_idx] + + def set_mode(self, mode: Optional[str]): + if mode is not None and mode not in ["write", "read", "skip"]: + raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'") + for cache in self.caches: + cache.mode = mode + + def next_sample(self): + """Move to the next sample for writing. Call this after processing + all condition images for one batch sample.""" + for cache in self.caches: + cache.next_sample() + + def clear(self): + for cache in self.caches: + cache.clear() + + +class GlmImageAttnProcessor: + """ + Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + + The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size, + text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: Optional[GlmImageLayerKVCache] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = encoder_hidden_states.dtype + + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query).to(dtype=dtype) + if attn.norm_k is not None: + key = attn.norm_k(key).to(dtype=dtype) + + # 3. 
Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, text_seq_length:, :, :] = apply_rotary_emb( + query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2 + ) + key[:, text_seq_length:, :, :] = apply_rotary_emb( + key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2 + ) + + if kv_cache is not None: + if kv_cache.mode == "write": + kv_cache.store(key, value) + elif kv_cache.mode == "read": + key, value = kv_cache.get(key, value) + elif kv_cache.mode == "skip": + pass + + # 4. Attention + if attention_mask is not None: + text_attn_mask = attention_mask + assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" + text_attn_mask = text_attn_mask.float().to(query.device) + mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) + mix_attn_mask[:, :text_seq_length] = text_attn_mask + mix_attn_mask = mix_attn_mask.unsqueeze(2) + attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2) + attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +@maybe_allow_in_graph +class GlmImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ) -> None: + super().__init__() + + # 1. Attention + self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-5, + processor=GlmImageAttnProcessor(), + ) + + # 2. Feedforward + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache: Optional[GlmImageLayerKVCache] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Timestep conditioning + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, encoder_hidden_states, temb) + + # 2. 
Attention + attention_kwargs = attention_kwargs or {} + + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + kv_cache=kv_cache, + **attention_kwargs, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) + + # 3. Feedforward + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( + 1 + c_scale_mlp.unsqueeze(1) + ) + c_shift_mlp.unsqueeze(1) + + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class GlmImageRotaryPosEmbed(nn.Module): + def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) + h_seq = torch.arange(height) + w_seq = torch.arange(width) + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + # Create position matrices for height and width + # [height, 1, dim//4] and [1, width, dim//4] + freqs_h = freqs_h.unsqueeze(1) + freqs_w = freqs_w.unsqueeze(0) + # Broadcast freqs_h and freqs_w to [height, width, dim//4] + freqs_h = freqs_h.expand(height, width, -1) + freqs_w = freqs_w.expand(height, width, -1) + + # Concatenate along last dimension to get [height, width, dim//2] + freqs = torch.cat([freqs_h, freqs_w], dim=-1) + freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim] + freqs = freqs.reshape(height * width, -1) + return (freqs.cos(), freqs.sin()) + + +class GlmImageAdaLayerNormContinuous(nn.Module): + """ + GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the + Linear on conditioning embedding. 
+ """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "layer_norm", + ): + super().__init__() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # *** NO SiLU here *** + emb = self.linear(conditioning_embedding.to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): + r""" + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + attention_head_dim (`int`, defaults to `40`): + The number of channels in each head. + num_attention_heads (`int`, defaults to `64`): + The number of heads to use for multi-head attention. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_embed_dim (`int`, defaults to `1472`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + condition_dim (`int`, defaults to `256`): + The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, + crop_coords). + pos_embed_max_size (`int`, defaults to `128`): + The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added + to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 + means that the maximum supported height and width for image generation is `128 * vae_scale_factor * + patch_size => 128 * 8 * 2 => 2048`. + sample_size (`int`, defaults to `128`): + The base resolution of input latents. If height/width is not provided during generation, this value is used + to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` + """ + + _supports_gradient_checkpointing = True + _no_split_modules = [ + "GlmImageTransformerBlock", + "GlmImageImageProjector", + "GlmImageImageProjector", + ] + _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] + _skip_keys = ["kv_caches"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + text_embed_dim: int = 1472, + time_embed_dim: int = 512, + condition_dim: int = 256, + prior_vq_quantizer_codebook_size: int = 16384, + ): + super().__init__() + + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + pooled_projection_dim = 2 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels + + # 1. RoPE + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + + # 2. 
Patch & Text-timestep embedding + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) + self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=time_embed_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. Output projection + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + prior_token_id: torch.Tensor, + prior_token_drop: torch.Tensor, + timestep: torch.LongTensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + kv_caches: Optional[GlmImageKVCache] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + batch_size, num_channels, height, width = hidden_states.shape + + # 1. RoPE + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + + # 2. Patch & Timestep embeddings + p = self.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = self.image_projector(hidden_states) + encoder_hidden_states = self.glyph_projector(encoder_hidden_states) + prior_embedding = self.prior_token_embedding(prior_token_id) + prior_embedding[prior_token_drop] *= 0.0 + prior_hidden_states = self.prior_projector(prior_embedding) + + hidden_states = hidden_states + prior_hidden_states + + temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) + + # 3. Transformer blocks + for idx, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_caches[idx] if kv_caches is not None else None, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_cache=kv_caches[idx] if kv_caches is not None else None, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + + # Rearrange tensor from (B, H_p, W_p, C, p, p) to (B, C, H_p * p, W_p * p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index fb0ce1a30ff9..4f0775ac9fa0 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -312,7 +312,6 @@ def forward( timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D) pooled_projections = self.text_embedder(pooled_projection) - conditioning = timesteps_emb + pooled_projections token_replace_emb = None if self.image_condition_type == "token_replace": @@ -324,8 +323,9 @@ def forward( if self.guidance_embedder is not None: guidance_proj = self.time_proj(guidance) guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) - conditioning = conditioning + guidance_emb - + conditioning = timesteps_emb + guidance_emb + pooled_projections + else: + conditioning = timesteps_emb + pooled_projections return conditioning, token_replace_emb diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 76a02cb1a886..293ba996ea98 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -184,19 +184,32 @@ class HunyuanVideo15TimeEmbedding(nn.Module): The dimension of the output embedding. 
""" - def __init__(self, embedding_dim: int): + def __init__(self, embedding_dim: int, use_meanflow: bool = False): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.use_meanflow = use_meanflow + self.time_proj_r = None + self.timestep_embedder_r = None + if use_meanflow: + self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + def forward( self, timestep: torch.Tensor, + timestep_r: Optional[torch.Tensor] = None, ) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) + if timestep_r is not None: + timesteps_proj_r = self.time_proj_r(timestep_r) + timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype)) + timesteps_emb = timesteps_emb + timesteps_emb_r + return timesteps_emb @@ -567,6 +580,7 @@ def __init__( # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205 target_size: int = 640, # did not name sample_size since it is in pixel spaces task_type: str = "i2v", + use_meanflow: bool = False, ) -> None: super().__init__() @@ -582,7 +596,7 @@ def __init__( ) self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) - self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim) + self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow=use_meanflow) self.cond_type_embed = nn.Embedding(3, inner_dim) @@ -612,6 +626,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, + timestep_r: Optional[torch.LongTensor] = None, encoder_hidden_states_2: Optional[torch.Tensor] = None, encoder_attention_mask_2: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, @@ -643,7 +658,7 @@ def forward( image_rotary_emb = self.rope(hidden_states) # 2. 
Conditional embeddings - temb = self.time_embed(timestep) + temb = self.time_embed(timestep, timestep_r=timestep_r) hidden_states = self.x_embedder(hidden_states) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 316e79da4fd6..57b28991d255 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -165,9 +165,8 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, time_dim, bias=True) - @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, time): - args = torch.outer(time, self.freqs.to(device=time.device)) + args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) return time_embed @@ -269,7 +268,6 @@ def __init__(self, time_dim, model_dim, num_params): self.out_layer.weight.data.zero_() self.out_layer.bias.data.zero_() - @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): return self.out_layer(self.activation(x)) @@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel( "Kandinsky5TransformerEncoderBlock", "Kandinsky5TransformerDecoderBlock", ] + _keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py new file mode 100644 index 000000000000..3d38da1dfcf5 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -0,0 +1,546 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
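Similarly, a minimal, hypothetical sketch of how the LongCatImageTransformer2DModel added below is driven. It consumes Flux-style packed latents plus per-token position ids for RoPE; the small configuration, shapes, and import path here are assumptions for illustration only (note that attention_head_dim must equal the sum of axes_dims_rope).

import torch
from diffusers.models.transformers.transformer_longcat_image import LongCatImageTransformer2DModel

# Illustrative tiny config; the real defaults use 19 double + 38 single blocks and 3584-dim text states.
model = LongCatImageTransformer2DModel(
    patch_size=1,
    in_channels=8,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=16,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=32,
    axes_dims_rope=[4, 6, 6],   # must sum to attention_head_dim
).eval()

batch, image_seq_len, text_seq_len = 1, 4, 2
packed_latents = torch.randn(batch, image_seq_len, 8)   # (B, image tokens, in_channels)
prompt_embeds = torch.randn(batch, text_seq_len, 32)    # (B, text tokens, joint_attention_dim)
img_ids = torch.zeros(image_seq_len, 3)                 # per-token position ids for the 3 RoPE axes
txt_ids = torch.zeros(text_seq_len, 3)
timestep = torch.tensor([0.5])                          # scaled by 1000 inside forward

with torch.no_grad():
    out = model(
        hidden_states=packed_latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
    )
print(out.sample.shape)  # torch.Size([1, 4, 8]) -- per-token prediction, patch_size**2 * out_channels channels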
+ +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "LongCatImageAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class LongCatImageAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "LongCatImageAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + +class LongCatImageAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = LongCatImageAttnProcessor + _available_processors = [ + LongCatImageAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + 
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class LongCatImageSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=LongCatImageAttnProcessor(), + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class 
LongCatImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = LongCatImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=LongCatImageAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class LongCatImagePosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class LongCatImageTimestepEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timestep, hidden_dtype): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + return timesteps_emb + + +class LongCatImageTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, + AttentionMixin, +): + """ + The Transformer model introduced in Longcat-Image. 
+ """ + + _supports_gradient_checkpointing = True + _repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 3584, + pooled_projection_dim: int = 3584, + axes_dims_rope: List[int] = [16, 56, 56], + ): + super().__init__() + self.out_channels = in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.pooled_projection_dim = pooled_projection_dim + + self.pos_embed = LongCatImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + LongCatImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + LongCatImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + self.use_checkpoint = [True] * num_layers + self.use_single_checkpoint = [True] * num_single_layers + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
+ """ + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + temb = self.time_embed(timestep, hidden_states.dtype) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_single_checkpoint[index_block]: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py new file mode 100644 index 000000000000..b88f096e8033 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -0,0 +1,1350 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +def apply_split_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + + x_dtype = x.dtype + needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + # cos is (#b, h, t, r) -> reshape x to (b, h, t, dim_per_head) + # The cos/sin batch dim may only be broadcastable, so take batch size from x + b = x.shape[0] + _, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + # Split last dim (2*r) into (d=2, r) + last = x.shape[-1] + if last % 2 != 0: + raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.") + r = last // 2 + + # (..., 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float + first_x = split_x[..., :1, :] # (..., 1, r) + second_x = split_x[..., 1:, :] # (..., 1, r) + + cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + + out = out.to(dtype=x_dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output + of the model. This is typically a video (spatiotemporal) output. + audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`): + The audio output of the audiovisual model. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. In particular, the number of modulation parameters to be calculated is now configurable. 
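`apply_interleaved_rotary_emb` and `apply_split_rotary_emb` above apply the same rotation in two layouts: interleaved `(2i, 2i+1)` channel pairs versus a head-major split into two halves. A small self-contained check of the interleaved variant follows; the helper is restated locally so the snippet runs without this branch installed, and the toy shapes are illustrative.

import torch

def apply_interleaved_rotary_emb(x, freqs):
    # Same logic as the helper above: rotate each (even, odd) channel pair.
    cos, sin = freqs
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

batch, seq, dim = 1, 2, 4
x = torch.randn(batch, seq, dim)
theta = torch.rand(batch, seq, dim // 2)  # one angle per channel pair
cos = theta.cos().repeat_interleave(2, dim=-1)  # interleaved layout produced by the RoPE module
sin = theta.sin().repeat_interleave(2, dim=-1)

out = apply_interleaved_rotary_emb(x, (cos, sin))

# Reference: explicit 2x2 rotation of each (x0, x1) pair by its angle.
pairs = x.unflatten(-1, (-1, 2))
ref = torch.stack(
    [
        pairs[..., 0] * theta.cos() - pairs[..., 1] * theta.sin(),
        pairs[..., 1] * theta.cos() + pairs[..., 0] * theta.sin(),
    ],
    dim=-1,
).flatten(-2)
assert torch.allclose(out, ref, atol=1e-6)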
+ + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters which will be calculated in the first return argument. The default of 6 + is standard, but sometimes we may want to have a different (usually smaller) number of modulation + parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
+ ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2Attention(torch.nn.Module, AttentionModuleMixin): + r""" + Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. 
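Because the audio and video streams have different token counts and timestamps, the processor above accepts `query_rotary_emb` and `key_rotary_emb` separately. Below is a rough standalone sketch of the audio-to-video case (video queries, audio keys/values); the head split and the direct call to `scaled_dot_product_attention` stand in for `dispatch_attention_fn` and are illustrative assumptions only.

import torch
import torch.nn.functional as F

def rotate(x, angles):
    # Interleaved RoPE, as in apply_interleaved_rotary_emb above.
    cos = angles.cos().repeat_interleave(2, dim=-1)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)
    x_rot = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    return x * cos + x_rot * sin

batch, heads, dim = 1, 2, 8
num_video_tokens, num_audio_tokens = 6, 3

video_q = torch.randn(batch, num_video_tokens, heads * dim)
audio_k = torch.randn(batch, num_audio_tokens, heads * dim)
audio_v = torch.randn(batch, num_audio_tokens, heads * dim)

video_angles = torch.rand(batch, num_video_tokens, heads * dim // 2)  # from video timestamps
audio_angles = torch.rand(batch, num_audio_tokens, heads * dim // 2)  # from audio timestamps

# query_rotary_emb comes from the video grid, key_rotary_emb from the audio grid.
q = rotate(video_q, video_angles).unflatten(2, (heads, dim)).transpose(1, 2)
k = rotate(audio_k, audio_angles).unflatten(2, (heads, dim)).transpose(1, 2)
v = audio_v.unflatten(2, (heads, dim)).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # (batch, heads, num_video_tokens, dim)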
+ """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + self.rope_type = rope_type + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim, + audio_cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + + # 1. Self-Attention (video and audio) + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 2. Prompt Cross-Attention + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio + self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video + self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 4. Feedforward layers + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) + + # 5. 
Per-Layer Modulation Parameters + # Self-Attention / Feedforward AdaLayerNorm-Zero mod params + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + + # Per-layer a2v, v2a Cross-Attention mod params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_video_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ca_audio_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + a2v_cross_attention_mask: Optional[torch.Tensor] = None, + v2a_cross_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # 1. Video and Audio Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + 
key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. Feedforward + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + + Args: + causal_offset (`int`, *optional*, defaults to `1`): + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE + treats the very first frame differently), but could also be 0 (for non-causal modeling). + """ + + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: Tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads + + # Video-specific + self.base_height = base_height + self.base_width = base_width + + # Audio-specific + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel + space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2) + where + - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the video latents. + num_frames (`int`): + Number of latent frames in the video latents. + height (`int`): + Latent height of the video latents. + width (`int`): + Latent width of the video latents. + device (`torch.device`): + Device on which to create the video grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2]. + """ + + # 1. 
Generate grid coordinates for each spatiotemporal dimension (frames, height, width) + # Always compute rope in fp32 + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches + + # 2. Get the patch boundaries with respect to the latent video grid + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + # Reshape to (batch_size, 3, num_patches, 2) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + # 3. Calculate the pixel space patch boundaries from the latent boundaries. + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # This is the (frame, height, width) dim + # Apply per-axis scaling to convert latent coordinates to pixel space coordinates + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift + # and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + + # Scale the temporal coordinates by the video FPS + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame. + This will ultimately have shape (batch_size, 3, num_patches, 2) where + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the audio latents. + num_frames (`int`): + Number of latent frames in the audio latents. + device (`torch.device`): + Device on which to create the audio grid. + shift (`int`, *optional*, defaults to `0`): + Offset on the latent indices. Different shift values correspond to different overlapping windows with + respect to the same underlying latent grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2]. + """ + + # 1. Generate coordinates in the frame (time) dimension. + # Always compute rope in fp32 + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + # 2. 
Calculate start timestamps in seconds with respect to the original spectrogram grid
+        audio_scale_factor = self.scale_factors[0]
+        # Scale back to mel spectrogram space
+        grid_start_mel = grid_f * audio_scale_factor
+        # Handle first frame causal offset, ensuring non-negative timestamps
+        grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0)
+        # Convert mel bins back into seconds
+        grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
+
+        # 3. Calculate end timestamps in seconds with respect to the original spectrogram grid
+        grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
+        grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0)
+        grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
+
+        audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1)  # [num_patches, 2]
+        audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1)  # [batch_size, num_patches, 2]
+        audio_coords = audio_coords.unsqueeze(1)  # [batch_size, 1, num_patches, 2]
+        return audio_coords
+
+    def prepare_coords(self, *args, **kwargs):
+        if self.modality == "video":
+            return self.prepare_video_coords(*args, **kwargs)
+        elif self.modality == "audio":
+            return self.prepare_audio_coords(*args, **kwargs)
+
+    def forward(
+        self, coords: torch.Tensor, device: Optional[Union[str, torch.device]] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        device = device or coords.device
+
+        # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn)
+        num_pos_dims = coords.shape[1]
+
+        # 1. If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch
+        # position index
+        if coords.ndim == 4:
+            coords_start, coords_end = coords.chunk(2, dim=-1)
+            coords = (coords_start + coords_end) / 2.0
+            coords = coords.squeeze(-1)  # [B, num_pos_dims, num_patches]
+
+        # 2. Get coordinates as a fraction of the base data shape
+        if self.modality == "video":
+            max_positions = (self.base_num_frames, self.base_height, self.base_width)
+        elif self.modality == "audio":
+            max_positions = (self.base_num_frames,)
+        # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims]
+        grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device)
+        # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin
+        num_rope_elems = num_pos_dims * 2
+
+        # 3. Create a 1D grid of frequencies for RoPE
+        freqs_dtype = torch.float64 if self.double_precision else torch.float32
+        pow_indices = torch.pow(
+            self.theta,
+            torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
+        )
+        freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
+
+        # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
+        # (self.dim // num_elems,)
+        freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs  # [B, num_patches, num_pos_dims, self.dim // num_elems]
+        freqs = freqs.transpose(-1, -2).flatten(2)  # [B, num_patches, self.dim // 2]
+
+        # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim
+        # TODO: consider implementing this as a utility and reuse in `connectors.py`.
+
+        # src/diffusers/pipelines/ltx2/connectors.py
+        if self.rope_type == "interleaved":
+            cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
+            sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
+
+            if self.dim % num_rope_elems != 0:
+                cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
+                sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems])
+                cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
+                sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
+
+        elif self.rope_type == "split":
+            expected_freqs = self.dim // 2
+            current_freqs = freqs.shape[-1]
+            pad_size = expected_freqs - current_freqs
+            cos_freq = freqs.cos()
+            sin_freq = freqs.sin()
+
+            if pad_size != 0:
+                cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
+                sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
+
+                cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
+                sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
+
+            # Reshape freqs to be compatible with multi-head attention
+            b = cos_freq.shape[0]
+            t = cos_freq.shape[1]
+
+            cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
+            sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
+
+            cos_freqs = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)
+            sin_freqs = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)
+
+        return cos_freqs, sin_freqs
+
+
+class LTX2VideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
+    r"""
+    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+    Args:
+        in_channels (`int`, defaults to `128`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `128`):
+            The number of channels in the output.
+        patch_size (`int`, defaults to `1`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the temporal patches to use in the patch embedding layer.
+        num_attention_heads (`int`, defaults to `32`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        cross_attention_dim (`int`, defaults to `4096`):
+            The number of channels for cross attention heads.
+        num_layers (`int`, defaults to `48`):
+            The number of layers of Transformer blocks to use.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
+            The normalization layer to use.
+ """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + in_channels: int = 128, # Video Arguments + out_channels: Optional[int] = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: Tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: Optional[int] = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. Timestep Modulation Params and Embedding + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding + # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + + # 3.2. Global Cross Attention Modulation Parameters + # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params, + # which are then further modified by per-block modulaton params in each transformer block. 
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and + # video-to-audio (v2a) cross attention + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys + # and values (KV)) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys + # and values (KV)) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. Rotary Positional Embeddings (RoPE) + # Self-Attention + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # Audio-to-Video, Video-to-Audio Cross-Attention + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # 5. 
Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + audio_timestep: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 24.0, + audio_num_frames: Optional[int] = None, + video_coords: Optional[torch.Tensor] = None, + audio_coords: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + """ + Forward pass for LTX-2.0 audiovisual video transformer. + + Args: + hidden_states (`torch.Tensor`): + Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`. + audio_hidden_states (`torch.Tensor`): + Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`. + encoder_hidden_states (`torch.Tensor`): + Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + audio_encoder_hidden_states (`torch.Tensor`): + Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + timestep (`torch.Tensor`): + Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by + `self.config.timestep_scale_multiplier`. + audio_timestep (`torch.Tensor`, *optional*): + Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation + params. This is only used by certain pipelines such as the I2V pipeline. + encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. + audio_encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. + num_frames (`int`, *optional*): + The number of latent video frames. Used if calculating the video coordinates for RoPE. + height (`int`, *optional*): + The latent video height. Used if calculating the video coordinates for RoPE. + width (`int`, *optional*): + The latent video width. Used if calculating the video coordinates for RoPE. + fps: (`float`, *optional*, defaults to `24.0`): + The desired frames per second of the generated video. 
Used if calculating the video coordinates for + RoPE. + audio_num_frames: (`int`, *optional*): + The number of latent audio frames. Used if calculating the audio coordinates for RoPE. + video_coords (`torch.Tensor`, *optional*): + The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + audio_coords (`torch.Tensor`, *optional*): + The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + attention_kwargs (`Dict[str, Any]`, *optional*): + Optional dict of keyword args to be passed to the attention processor. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple. + + Returns: + `AudioVisualModelOutput` or `tuple`: + If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a + `tuple` is returned where the first element is the denoised video latent patch sequence and the second + element is the denoised audio latent patch sequence. + """ + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # Determine timestep for audio. + audio_timestep = audio_timestep if audio_timestep is not None else timestep + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. Prepare RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device + ) + + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. 
Prepare timestep embeddings and modulation parameters + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) + + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters + # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer + # modulation with scale_shift_table (and similarly for audio) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + # 3.2. Prepare global modality cross attention modulation parameters + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + + # 4. Prepare prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + + # 5. 
Run transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) + + # 6. Output layers (including unpatchification) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 0a09aa720b3f..139ceaefa4e9 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import is_torch_npu_available, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -530,11 +530,7 @@ def forward( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) - if is_torch_npu_available(): - freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = self.pos_embed(ids) + image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git 
a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index ccbc83ffca03..a87c120fdcd7 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -16,7 +16,6 @@ import torch from torch import nn -from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging @@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor: Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // patch_size)` is the number of patches. """ - return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) + b, c, h, w = img.shape + p = patch_size + + # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions + img = img.reshape(b, c, h // p, p, w // p, p) + + # Permute to (B, H//p, W//p, C, p, p) using einsum + # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width + img = torch.einsum("nchpwq->nhwcpq", img) + + # Flatten to (B, L, C * p * p) + img = img.reshape(b, -1, c * p * p) + return img def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor: @@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te Reconstructed image tensor of shape `(B, C, H, W)`. """ if isinstance(shape, tuple): - shape = shape[-2:] + h, w = shape[-2:] elif isinstance(shape, torch.Tensor): - shape = (int(shape[0]), int(shape[1])) + h, w = (int(shape[0]), int(shape[1])) else: raise NotImplementedError(f"shape type {type(shape)} not supported") - return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + b, l, d = seq.shape + p = patch_size + c = d // (p * p) + + # Reshape back to grid structure: (B, H//p, W//p, C, p, p) + seq = seq.reshape(b, h // p, w // p, c, p, p) + + # Permute back to image layout: (B, C, H//p, p, W//p, p) + # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width + seq = torch.einsum("nhwcpq->nchpwq", seq) + + # Final reshape to (B, C, H, W) + seq = seq.reshape(b, c, h, w) + return seq class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): @@ -694,6 +719,7 @@ def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype max_period=self.time_max_period, scale=self.time_factor, flip_sin_to_cos=True, # Match original cos, sin order + downscale_freq_shift=0.0, ).to(dtype) ) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index c0fa031b9faf..cf11d8e01fb4 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -14,6 +14,7 @@ import functools import math +from math import prod from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -23,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, FeedForward @@ 
-141,18 +142,53 @@ def apply_rotary_emb_qwen(
     return x_out.type_as(x)
 
 
+def compute_text_seq_len_from_mask(
+    encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: Optional[torch.Tensor]
+) -> Tuple[int, Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """
+    Compute text sequence length without assuming contiguous masks. Returns the padded text length used for RoPE,
+    the per-sample lengths (or None when no mask is given), and the mask normalized to bool.
+    """
+    batch_size, text_seq_len = encoder_hidden_states.shape[:2]
+    if encoder_hidden_states_mask is None:
+        return text_seq_len, None, None
+
+    if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
+        raise ValueError(
+            f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
+            f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
+        )
+
+    if encoder_hidden_states_mask.dtype != torch.bool:
+        encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
+
+    position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
+    active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
+    has_active = encoder_hidden_states_mask.any(dim=1)
+    per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
+    return text_seq_len, per_sample_len, encoder_hidden_states_mask
+
+
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim):
+    def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.use_additional_t_cond = use_additional_t_cond
+        if use_additional_t_cond:
+            self.addition_t_embedding = nn.Embedding(2, embedding_dim)
 
-    def forward(self, timestep, hidden_states):
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
 
         conditioning = timesteps_emb
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                raise ValueError("When `use_additional_t_cond` is True, `addition_t_cond` must be provided.")
+            addition_t_emb = self.addition_t_embedding(addition_t_cond)
+            addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
+            conditioning = conditioning + addition_t_emb
 
         return conditioning
 
@@ -197,21 +233,50 @@ def rope_params(self, index, dim, theta=10000):
     def forward(
         self,
         video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
-        txt_seq_lens: List[int],
-        device: torch.device,
+        txt_seq_lens: Optional[List[int]] = None,
+        device: torch.device = None,
+        max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
                A list of 3 integers [frame, height, width] representing the shape of the video.
-           txt_seq_lens (`List[int]`):
-               A list of integers of length batch_size representing the length of each text prompt.
-           device: (`torch.device`):
+           txt_seq_lens (`List[int]`, *optional*, **Deprecated**):
+               Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
+           device: (`torch.device`, *optional*):
                The device on which to perform the RoPE computation.
+           max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
+               The maximum text sequence length for RoPE computation.
This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `max_txt_seq_len` instead. " + "The new parameter accepts a single int or tensor value representing the maximum text sequence length.", + standard_warn=False, + ) + if max_txt_seq_len is None: + # Use max of txt_seq_lens for backward compatibility + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + + # Validate batch inference with variable-sized images + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if all instances have the same size + first_fhw = video_fhw[0] + if not all(fhw == first_fhw for fhw in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedRope. " + "All images in the batch should have the same dimensions (frame, height, width). " + f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) if isinstance(video_fhw, list): video_fhw = video_fhw[0] @@ -223,8 +288,7 @@ def forward( for idx, fhw in enumerate(video_fhw): frame, height, width = fhw # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs - video_freq = self._compute_video_freqs(frame, height, width, idx) - video_freq = video_freq.to(device) + video_freq = self._compute_video_freqs(frame, height, width, idx, device) vid_freqs.append(video_freq) if self.scale_rope: @@ -232,17 +296,23 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
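For reference, a minimal standalone sketch of the mask-to-length logic introduced above by `compute_text_seq_len_from_mask` (illustrative only, not part of this patch; the helper name `text_len_from_mask` and the demo tensors are invented for the example). The padded length is what ends up passed as `max_txt_seq_len` in place of the deprecated `txt_seq_lens`:

import torch

def text_len_from_mask(mask: torch.Tensor) -> torch.Tensor:
    # Last active position + 1 per sample; rows with no active tokens fall back to the padded length.
    positions = torch.arange(mask.shape[1], device=mask.device)
    active = torch.where(mask.bool(), positions, positions.new_zeros(()))
    has_active = mask.bool().any(dim=1)
    return torch.where(has_active, active.max(dim=1).values + 1, torch.as_tensor(mask.shape[1]))

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 0, 1, 1, 0]])
print(text_len_from_mask(mask))  # tensor([3, 4]); non-contiguous masks are handled
# RoPE itself is sized with the full padded length, i.e. max_txt_seq_len = mask.shape[1] here.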
vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=128) - def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor: + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None + ) -> torch.Tensor: seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: @@ -258,6 +328,147 @@ def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0 return freqs.clone().contiguous() +class QwenEmbedLayer3DRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + assert dim % 2 == 0 + freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + max_txt_seq_len: Union[int, torch.Tensor], + device: torch.device = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer + structures. + max_txt_seq_len (`int` or `torch.Tensor`): + The maximum text sequence length for RoPE computation. This should match the encoder hidden states + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + device: (`torch.device`, *optional*): + The device on which to perform the RoPE computation. + """ + # Validate batch inference with variable-sized images + # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers + if isinstance(video_fhw, list) and len(video_fhw) > 1: + # Check if this is batch inference (list of layer lists/tuples) + first_entry = video_fhw[0] + if not all(entry == first_entry for entry in video_fhw): + logger.warning( + "Batch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. " + "All images in the batch should have the same layer structure. " + f"Detected sizes: {video_fhw}. 
Using the first image's layer structure {first_entry} " + "for RoPE computation, which may lead to incorrect results for other images in the batch." + ) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + layer_num = len(video_fhw) - 1 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + if idx != layer_num: + video_freq = self._compute_video_freqs(frame, height, width, idx, device) + else: + ### For the condition image, we set the layer index to -1 + video_freq = self._compute_condition_freqs(frame, height, width, device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_vid_index = max(max_vid_index, layer_num) + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + vid_freqs = torch.cat(vid_freqs, dim=0) + + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=None) + def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None): + seq_lens = frame * height * width + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + @functools.lru_cache(maxsize=None) + def _compute_condition_freqs(self, frame, height, width, device: torch.device = None): + seq_lens = frame * height * width + pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs + neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs + + freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) + if self.scale_rope: + freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) + freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) + freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + 
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) + return freqs.clone().contiguous() + + class QwenDoubleStreamAttnProcessor2_0: """ Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor @@ -330,7 +541,6 @@ def __call__( joint_key = torch.cat([txt_key, img_key], dim=1) joint_value = torch.cat([txt_value, img_value], dim=1) - # Compute joint attention joint_hidden_states = dispatch_attention_fn( joint_query, joint_key, @@ -363,7 +573,13 @@ def __call__( @maybe_allow_in_graph class QwenImageTransformerBlock(nn.Module): def __init__( - self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + zero_cond_t: bool = False, ): super().__init__() @@ -403,10 +619,43 @@ def __init__( self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") - def _modulate(self, x, mod_params): + self.zero_cond_t = zero_cond_t + + def _modulate(self, x, mod_params, index=None): """Apply modulation to input tensor""" + # x: b l d, shift: b d, scale: b d, gate: b d shift, scale, gate = mod_params.chunk(3, dim=-1) - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + if index is not None: + # Assuming mod_params batch dim is 2*actual_batch (chunked into 2 parts) + # So shift, scale, gate have shape [2*actual_batch, d] + actual_batch = shift.size(0) // 2 + shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:] # each: [actual_batch, d] + scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:] + gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:] + + # index: [b, l] where b is actual batch size + # Expand to [b, l, 1] to match feature dimension + index_expanded = index.unsqueeze(-1) # [b, l, 1] + + # Expand chunks to [b, 1, d] then broadcast to [b, l, d] + shift_0_exp = shift_0.unsqueeze(1) # [b, 1, d] + shift_1_exp = shift_1.unsqueeze(1) # [b, 1, d] + scale_0_exp = scale_0.unsqueeze(1) + scale_1_exp = scale_1.unsqueeze(1) + gate_0_exp = gate_0.unsqueeze(1) + gate_1_exp = gate_1.unsqueeze(1) + + # Use torch.where to select based on index + shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp) + scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp) + gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp) + else: + shift_result = shift.unsqueeze(1) + scale_result = scale.unsqueeze(1) + gate_result = gate.unsqueeze(1) + + return x * (1 + scale_result) + shift_result, gate_result def forward( self, @@ -416,9 +665,13 @@ def forward( temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, + modulate_index: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Get modulation parameters for both streams img_mod_params = self.img_mod(temb) # [B, 6*dim] + + if self.zero_cond_t: + temb = torch.chunk(temb, 2, dim=0)[0] txt_mod_params = self.txt_mod(temb) # [B, 6*dim] # Split modulation parameters for norm1 and norm2 @@ -427,7 +680,7 @@ def forward( # Process image stream - norm1 + modulation img_normed = self.img_norm1(hidden_states) - img_modulated, img_gate1 = 
self._modulate(img_normed, img_mod1) + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index) # Process text stream - norm1 + modulation txt_normed = self.txt_norm1(encoder_hidden_states) @@ -457,7 +710,7 @@ def forward( # Process image stream - norm2 + MLP img_normed2 = self.img_norm2(hidden_states) - img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2, modulate_index) img_mlp_output = self.img_mlp(img_modulated2) hidden_states = hidden_states + img_gate2 * img_mlp_output @@ -508,11 +761,14 @@ class QwenImageTransformer2DModel( _no_split_modules = ["QwenImageTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["QwenImageTransformerBlock"] + # Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702 _cp_plan = { - "": { + "transformer_blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "transformer_blocks.*": { + "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), }, "pos_embed": { 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), @@ -533,14 +789,22 @@ def __init__( joint_attention_dim: int = 3584, guidance_embeds: bool = False, # TODO: this should probably be removed axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + zero_cond_t: bool = False, + use_additional_t_cond: bool = False, + use_layer3d_rope: bool = False, ): super().__init__() self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim - self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + if not use_layer3d_rope: + self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) + else: + self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True) - self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + self.time_text_embed = QwenTimestepProjEmbeddings( + embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond + ) self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) @@ -553,6 +817,7 @@ def __init__( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, + zero_cond_t=zero_cond_t, ) for _ in range(num_layers) ] @@ -562,6 +827,7 @@ def __init__( self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False + self.zero_cond_t = zero_cond_t def forward( self, @@ -574,6 +840,7 @@ def forward( guidance: torch.Tensor = None, # TODO: this should probably be removed attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, + additional_t_cond=None, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: """ @@ -584,14 +851,25 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): - Mask of the input conditions. 
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*): + Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens. + Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern + (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention. timestep ( `torch.LongTensor`): Used to indicate denoising step. + img_shapes (`List[Tuple[int, int, int]]`, *optional*): + Image shapes for RoPE computation. + txt_seq_lens (`List[int]`, *optional*, **Deprecated**): + Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be + used to compute RoPE sequence length. + guidance (`torch.Tensor`, *optional*): + Guidance tensor for conditional generation. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_block_samples (*optional*): + ControlNet block samples to add to the transformer blocks. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. @@ -600,6 +878,15 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if txt_seq_lens is not None: + deprecate( + "txt_seq_lens", + "0.39.0", + "Passing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. " + "Please use `encoder_hidden_states_mask` instead. 
" + "The mask-based approach is more flexible and supports variable-length sequences.", + standard_warn=False, + ) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -618,19 +905,45 @@ def forward( hidden_states = self.img_in(hidden_states) timestep = timestep.to(hidden_states.dtype) + + if self.zero_cond_t: + timestep = torch.cat([timestep, timestep * 0], dim=0) + modulate_index = torch.tensor( + [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes], + device=timestep.device, + dtype=torch.int, + ) + else: + modulate_index = None + encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + encoder_hidden_states, encoder_hidden_states_mask + ) + if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 temb = ( - self.time_text_embed(timestep, hidden_states) + self.time_text_embed(timestep, hidden_states, additional_t_cond) if guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) + else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + + # Construct joint attention mask once to avoid reconstructing in every block + # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility + block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} + if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] + batch_size, image_seq_len = hidden_states.shape[:2] + image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) + joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) + block_attention_kwargs["attention_mask"] = joint_attention_mask for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -638,19 +951,22 @@ def forward( block, hidden_states, encoder_hidden_states, - encoder_hidden_states_mask, + None, # Don't pass encoder_hidden_states_mask (using attention_mask instead) temb, image_rotary_emb, + block_attention_kwargs, + modulate_index, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, + encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead) temb=temb, image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=attention_kwargs, + joint_attention_kwargs=block_attention_kwargs, + modulate_index=modulate_index, ) # controlnet residual @@ -659,6 +975,8 @@ def forward( interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + if self.zero_cond_t: + temb = temb.chunk(2, dim=0)[0] # Use only the image part (hidden_states) from the dual-stream blocks hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) diff --git a/src/diffusers/models/transformers/transformer_wan.py 
b/src/diffusers/models/transformers/transformer_wan.py index f7693ec5d3ac..132f615f2199 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -134,7 +134,8 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=None, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -147,7 +148,8 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -552,9 +554,11 @@ class WanTransformer3DModel( "blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.*": { - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + # We need to disable the splitting of encoder_hidden_states because the image_encoder + # (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape + # of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation + # —to be indivisible by the number of devices in the CP. "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), "": { "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py index 6a47a67385a3..ae1c3095a17f 100644 --- a/src/diffusers/models/transformers/transformer_wan_animate.py +++ b/src/diffusers/models/transformers/transformer_wan_animate.py @@ -166,9 +166,11 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates # set to 1, which should be equivalent to a 2D convolution expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1) + x = x.to(expanded_kernel.dtype) x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels) # Main Conv2D with scaling + x = x.to(self.weight.dtype) x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) # Activation with fused bias, if using @@ -609,7 +611,8 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=None, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -622,7 +625,8 @@ def apply_rotary_emb( dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -767,7 +771,7 @@ def forward(self, 
encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: return hidden_states -# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding +# Modified from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding class WanTimeTextImageEmbedding(nn.Module): def __init__( self, @@ -801,10 +805,12 @@ def forward( if timestep_seq_len is not None: timestep = timestep.unflatten(0, (-1, timestep_seq_len)) - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype - if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: - timestep = timestep.to(time_embedder_dtype) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + if self.time_embedder.linear_1.weight.dtype.is_floating_point: + time_embedder_dtype = self.time_embedder.linear_1.weight.dtype + else: + time_embedder_dtype = encoder_hidden_states.dtype + + temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states) timestep_proj = self.time_proj(self.act_fn(temb)) encoder_hidden_states = self.text_embedder(encoder_hidden_states) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5c401b9d202b..5983c34ab640 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -32,6 +32,7 @@ ADALN_EMBED_DIM = 256 SEQ_MULTI_OF = 32 +X_PAD_DIM = 64 class TimestepEmbedder(nn.Module): @@ -152,6 +153,20 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso return output +def select_per_token( + value_noisy: torch.Tensor, + value_clean: torch.Tensor, + noise_mask: torch.Tensor, + seq_len: int, +) -> torch.Tensor: + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) + return torch.where( + noise_mask_expanded == 1, + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), + value_clean.unsqueeze(1).expand(-1, seq_len, -1), + ) + + class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int): super().__init__() @@ -215,12 +230,37 @@ def forward( attn_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, + noise_mask: Optional[torch.Tensor] = None, + adaln_noisy: Optional[torch.Tensor] = None, + adaln_clean: Optional[torch.Tensor] = None, ): if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation: different modulation for noisy/clean tokens + mod_noisy = self.adaLN_modulation(adaln_noisy) + mod_clean = self.adaLN_modulation(adaln_clean) + + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) + + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() + + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy + scale_msa_clean, scale_mlp_clean = 1.0 + 
scale_msa_clean, 1.0 + scale_mlp_clean + + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) + else: + # Global modulation: same modulation for all tokens (avoid double select) + mod = self.adaLN_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp # Attention block attn_out = self.attention( @@ -252,9 +292,21 @@ def __init__(self, hidden_size, out_channels): nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), ) - def forward(self, x, c): - scale = 1.0 + self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) + def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None): + seq_len = x.shape[1] + + if noise_mask is not None: + # Per-token modulation + scale_noisy = 1.0 + self.adaLN_modulation(c_noisy) + scale_clean = 1.0 + self.adaLN_modulation(c_clean) + scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len) + else: + # Original global modulation + assert c is not None, "Either c or (c_noisy, c_clean) must be provided" + scale = 1.0 + self.adaLN_modulation(c) + scale = scale.unsqueeze(1) + + x = self.norm_final(x) * scale x = self.linear(x) return x @@ -325,6 +377,7 @@ def __init__( norm_eps=1e-5, qk_norm=True, cap_feat_dim=2560, + siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni rope_theta=256.0, t_scale=1000.0, axes_dims=[32, 48, 48], @@ -386,6 +439,31 @@ def __init__( self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + # Optional SigLIP components (for Omni variant) + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True) + ) + self.siglip_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 2000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim))) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -402,252 +480,561 @@ def __init__( self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + def unpatchify( + self, + x: List[torch.Tensor], + size: List[Tuple], + patch_size, + f_patch_size, + x_pos_offsets: Optional[List[Tuple[int, int]]] = None, + ) -> List[torch.Tensor]: pH = pW = patch_size pF = f_patch_size bsz = len(x) assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x + + if 
x_pos_offsets is not None: + # Omni: extract target image from unified sequence (cond_images + target) + result = [] + for i in range(bsz): + unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]] + cu_len = 0 + x_item = None + for j in range(len(size[i])): + if size[i][j] is None: + ori_len = 0 + pad_len = SEQ_MULTI_OF + cu_len += pad_len + ori_len + else: + F, H, W = size[i][j] + ori_len = (F // pF) * (H // pH) * (W // pW) + pad_len = (-ori_len) % SEQ_MULTI_OF + x_item = ( + unified_x[cu_len : cu_len + ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + cu_len += ori_len + pad_len + result.append(x_item) # Return only the last (target) image + return result + else: + # Original mode: simple unpatchify + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x @staticmethod def create_coordinate_grid(size, start=None, device=None): if start is None: start = (0 for _ in size) - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] grids = torch.meshgrid(axes, indexing="ij") return torch.stack(grids, dim=-1) - def patchify_and_embed( + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" + pH, pW, pF = patch_size, patch_size, f_patch_size + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) + + def _pad_with_ids( self, - all_image: List[torch.Tensor], - all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, + feat: torch.Tensor, + pos_grid_size: Tuple, + pos_start: Tuple, + device: torch.device, + noise_mask_val: Optional[int] = None, ): - pH = pW = patch_size - pF = f_patch_size + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" + ori_len = len(feat) + pad_len = (-ori_len) % SEQ_MULTI_OF + total_len = ori_len + pad_len + + # Pos IDs + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) + if pad_len > 0: + pad_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(pad_len, 1) + ) + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) + pad_mask = torch.cat( + [ + torch.zeros(ori_len, dtype=torch.bool, device=device), + torch.ones(pad_len, dtype=torch.bool, device=device), + ] + ) + else: + pos_ids = ori_pos_ids + padded_feat = feat + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) + + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level + return padded_feat, pos_ids, pad_mask, total_len, noise_mask + + def patchify_and_embed( + self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int + ): + """Patchify for basic mode: single image per batch item.""" device = all_image[0].device + 
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] - all_image_out = [] - all_image_size = [] - all_image_pos_ids = [] - all_image_pad_mask = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_feats_out = [] - - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): - ### Process Caption - cap_ori_len = len(cap_feat) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - # padded position ids - cap_padded_pos_ids = self.create_coordinate_grid( - size=(cap_ori_len + cap_padding_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - all_cap_pos_ids.append(cap_padded_pos_ids) - # pad mask - cap_pad_mask = torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, + for image, cap_feat in zip(all_image, all_cap_feats): + # Caption + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device ) - all_cap_pad_mask.append( - cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + all_cap_out.append(cap_out) + all_cap_pos_ids.append(cap_pos_ids) + all_cap_pad_mask.append(cap_pad_mask) + + # Image + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device ) + all_img_out.append(img_out) + all_img_size.append(size) + all_img_pos_ids.append(img_pos_ids) + all_img_pad_mask.append(img_pad_mask) - # padded feature - cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) - all_cap_feats_out.append(cap_padded_feat) - - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + return ( + all_img_out, + all_cap_out, + all_img_size, + all_img_pos_ids, + all_cap_pos_ids, + all_img_pad_mask, + all_cap_pad_mask, + ) - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + def patchify_and_embed_omni( + self, + all_x: List[List[torch.Tensor]], + all_cap_feats: List[List[torch.Tensor]], + all_siglip_feats: List[List[torch.Tensor]], + patch_size: int, + f_patch_size: int, + images_noise_mask: List[List[int]], + ): + """Patchify for omni mode: multiple images per batch item with noise masks.""" + bsz = len(all_x) + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], [] + all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], [] + all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], [] - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_ori_len + cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padded_pos_ids = torch.cat( - [ - image_ori_pos_ids, - self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) - .flatten(0, 2) - .repeat(image_padding_len, 1), - ], 
- dim=0, - ) - all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) - # pad mask - image_pad_mask = torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_image_pad_mask.append( - image_pad_mask - if image_padding_len > 0 - else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) - ) - # padded feature - image_padded_feat = torch.cat( - [image, image[-1:].repeat(image_padding_len, 1)], - dim=0, - ) - all_image_out.append(image_padded_feat if image_padding_len > 0 else image) + for i in range(bsz): + num_images = len(all_x[i]) + cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], [] + cap_end_pos = [] + cap_cu_len = 1 + + # Process captions + for j, cap_item in enumerate(all_cap_feats[i]): + noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1 + cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids( + cap_item, + (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1), + (cap_cu_len, 0, 0), + device, + noise_val, + ) + cap_feats_list.append(cap_out) + cap_pos_list.append(cap_pos) + cap_mask_list.append(cap_mask) + cap_lens.append(cap_len) + cap_noise.extend(cap_nm) + cap_cu_len += len(cap_item) + cap_end_pos.append(cap_cu_len) + cap_cu_len += 2 # for image vae and siglip tokens + + all_cap_out.append(torch.cat(cap_feats_list, dim=0)) + all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0)) + all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0)) + all_cap_len.append(cap_lens) + all_cap_noise_mask.append(cap_noise) + + # Process images + x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], [] + for j, x_item in enumerate(all_x[i]): + noise_val = images_noise_mask[i][j] + if x_item is not None: + x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size) + x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids( + x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val + ) + x_size.append(size) + else: + x_len = SEQ_MULTI_OF + x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device) + x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1) + x_mask = torch.ones(x_len, dtype=torch.bool, device=device) + x_nm = [noise_val] * x_len + x_size.append(None) + x_feats_list.append(x_out) + x_pos_list.append(x_pos) + x_mask_list.append(x_mask) + x_lens.append(x_len) + x_noise.extend(x_nm) + + all_x_out.append(torch.cat(x_feats_list, dim=0)) + all_x_pos_ids.append(torch.cat(x_pos_list, dim=0)) + all_x_pad_mask.append(torch.cat(x_mask_list, dim=0)) + all_x_size.append(x_size) + all_x_len.append(x_lens) + all_x_noise_mask.append(x_noise) + + # Process siglip + if all_siglip_feats[i] is None: + all_sig_len.append([0] * num_images) + all_sig_out.append(None) + else: + sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], [] + for j, sig_item in enumerate(all_siglip_feats[i]): + noise_val = images_noise_mask[i][j] + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C) + sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids( + sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val + ) + # Scale position IDs to match x resolution + if x_size[j] is not None: + sig_pos = sig_pos.float() + sig_pos[..., 1] = 
sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1) + sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1) + sig_pos = sig_pos.to(torch.int32) + else: + sig_len = SEQ_MULTI_OF + sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device) + sig_pos = ( + self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1) + ) + sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device) + sig_nm = [noise_val] * sig_len + sig_feats_list.append(sig_out) + sig_pos_list.append(sig_pos) + sig_mask_list.append(sig_mask) + sig_lens.append(sig_len) + sig_noise.extend(sig_nm) + + all_sig_out.append(torch.cat(sig_feats_list, dim=0)) + all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0)) + all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0)) + all_sig_len.append(sig_lens) + all_sig_noise_mask.append(sig_noise) + + # Compute x position offsets + all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)] return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, + all_x_out, + all_cap_out, + all_sig_out, + all_x_size, + all_x_pos_ids, all_cap_pos_ids, - all_image_pad_mask, + all_sig_pos_ids, + all_x_pad_mask, all_cap_pad_mask, + all_sig_pad_mask, + all_x_pos_offsets, + all_x_noise_mask, + all_cap_noise_mask, + all_sig_noise_mask, ) + def _prepare_sequence( + self, + feats: List[torch.Tensor], + pos_ids: List[torch.Tensor], + inner_pad_mask: List[torch.Tensor], + pad_token: torch.nn.Parameter, + noise_mask: Optional[List[List[int]]] = None, + device: torch.device = None, + ): + """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask.""" + item_seqlens = [len(f) for f in feats] + max_seqlen = max(item_seqlens) + bsz = len(feats) + + # Pad token + feats_cat = torch.cat(feats, dim=0) + feats_cat[torch.cat(inner_pad_mask)] = pad_token + feats = list(feats_cat.split(item_seqlens, dim=0)) + + # RoPE + freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0)) + + # Pad to batch + feats = pad_sequence(feats, batch_first=True, padding_value=0.0) + freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]] + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(item_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if noise_mask is not None: + noise_mask_tensor = pad_sequence( + [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask], + batch_first=True, + padding_value=0, + )[:, : feats.shape[1]] + + return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor + + def _build_unified_sequence( + self, + x: torch.Tensor, + x_freqs: torch.Tensor, + x_seqlens: List[int], + x_noise_mask: Optional[List[List[int]]], + cap: torch.Tensor, + cap_freqs: torch.Tensor, + cap_seqlens: List[int], + cap_noise_mask: Optional[List[List[int]]], + siglip: Optional[torch.Tensor], + siglip_freqs: Optional[torch.Tensor], + siglip_seqlens: Optional[List[int]], + siglip_noise_mask: Optional[List[List[int]]], + omni_mode: bool, + device: torch.device, + ): + """Build unified sequence: x, cap, and optionally siglip. 
+ Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip] + """ + bsz = len(x_seqlens) + unified = [] + unified_freqs = [] + unified_noise_mask = [] + + for i in range(bsz): + x_len, cap_len = x_seqlens[i], cap_seqlens[i] + + if omni_mode: + # Omni: [cap, x, siglip] + if siglip is not None and siglip_seqlens is not None: + sig_len = siglip_seqlens[i] + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]])) + unified_freqs.append( + torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]]) + ) + unified_noise_mask.append( + torch.tensor( + cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device + ) + ) + else: + unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]])) + unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]])) + unified_noise_mask.append( + torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device) + ) + else: + # Basic: [x, cap] + unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]])) + unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]])) + + # Compute unified seqlens + if omni_mode: + if siglip is not None and siglip_seqlens is not None: + unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)] + else: + unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)] + + max_seqlen = max(unified_seqlens) + + # Pad to batch + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0) + + # Attention mask + attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_seqlens): + attn_mask[i, :seq_len] = 1 + + # Noise mask + noise_mask_tensor = None + if omni_mode: + noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[ + :, : unified.shape[1] + ] + + return unified, unified_freqs, attn_mask, noise_mask_tensor + def forward( self, - x: List[torch.Tensor], + x: Union[List[torch.Tensor], List[List[torch.Tensor]]], t, - cap_feats: List[torch.Tensor], - patch_size=2, - f_patch_size=1, + cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]], return_dict: bool = True, + controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, + siglip_feats: Optional[List[List[torch.Tensor]]] = None, + image_noise_mask: Optional[List[List[int]]] = None, + patch_size: int = 2, + f_patch_size: int = 1, ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + """ + Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine + -> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify + """ + assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size + omni_mode = isinstance(x[0], list) + device = x[0][-1].device if omni_mode else x[0].device + + if omni_mode: + # Dual embeddings: noisy (t) and clean (t=1) + t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1]) + t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1]) + adaln_input = None + else: + # Single embedding for all tokens + adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0]) + t_noisy = t_clean = None + + # Patchify + if omni_mode: + ( + x, + cap_feats, + siglip_feats, + x_size, + x_pos_ids, + cap_pos_ids, + 
siglip_pos_ids, + x_pad_mask, + cap_pad_mask, + siglip_pad_mask, + x_pos_offsets, + x_noise_mask, + cap_noise_mask, + siglip_noise_mask, + ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask) + else: + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_pad_mask, + cap_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None + + # X embed & refine + x_seqlens = [len(xi) for xi in x] + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed + x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence( + list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device + ) - bsz = len(x) - device = x[0].device - t = t * self.t_scale - t = self.t_embedder(t) + for layer in self.noise_refiner: + x = ( + self._gradient_checkpointing_func( + layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean + ) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean) + ) - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) - - # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - - # Match t_embedder output dtype to x for layerwise casting compatibility - adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) - - x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - x_freqs_cis = x_freqs_cis[:, : x.shape[1]] - - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: - x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) - else: - for layer in self.noise_refiner: - x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) - - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list( - self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + # Cap embed & refine + cap_seqlens = [len(ci) for ci in cap_feats] + cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed + cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence( + list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device ) - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = 
pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + for layer in self.context_refiner: + cap_feats = ( + self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(cap_feats, cap_mask, cap_freqs) + ) - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 + # Siglip embed & refine + siglip_seqlens = siglip_freqs = None + if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None: + siglip_seqlens = [len(si) for si in siglip_feats] + siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed + siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence( + list(siglip_feats.split(siglip_seqlens, dim=0)), + siglip_pos_ids, + siglip_pad_mask, + self.siglip_pad_token, + None, + device, + ) - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: - cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) - else: - for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + for layer in self.siglip_refiner: + siglip_feats = ( + self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(siglip_feats, siglip_mask, siglip_freqs) + ) - # unified - unified = [] - unified_freqs_cis = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) + # Unified sequence + unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence( + x, + x_freqs, + x_seqlens, + x_noise_mask, + cap_feats, + cap_freqs, + cap_seqlens, + cap_noise_mask, + siglip_feats, + siglip_freqs, + siglip_seqlens, + siglip_noise_mask, + omni_mode, + device, + ) - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: - unified = self._gradient_checkpointing_func( - layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input + # Main transformer layers + for layer_idx, layer in enumerate(self.layers): + unified = ( + self._gradient_checkpointing_func( + layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean ) - else: - for layer in self.layers: - unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if torch.is_grad_enabled() and self.gradient_checkpointing + else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, 
t_noisy, t_clean) + ) + if controlnet_block_samples is not None and layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] - unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) - unified = list(unified.unbind(dim=0)) - x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + unified = ( + self.all_final_layer[f"{patch_size}-{f_patch_size}"]( + unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean + ) + if omni_mode + else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input) + ) - if not return_dict: - return (x,) + # Unpatchify + x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets) - return Transformer2DModelOutput(sample=x) + return (x,) if not return_dict else Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 252b9f33dfe8..823a3d263ea9 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -52,6 +52,13 @@ "FluxKontextAutoBlocks", "FluxKontextModularPipeline", ] + _import_structure["flux2"] = [ + "Flux2AutoBlocks", + "Flux2KleinAutoBlocks", + "Flux2KleinBaseAutoBlocks", + "Flux2ModularPipeline", + "Flux2KleinModularPipeline", + ] _import_structure["qwenimage"] = [ "QwenImageAutoBlocks", "QwenImageModularPipeline", @@ -59,6 +66,12 @@ "QwenImageEditAutoBlocks", "QwenImageEditPlusModularPipeline", "QwenImageEditPlusAutoBlocks", + "QwenImageLayeredModularPipeline", + "QwenImageLayeredAutoBlocks", + ] + _import_structure["z_image"] = [ + "ZImageAutoBlocks", + "ZImageModularPipeline", ] _import_structure["components_manager"] = ["ComponentsManager"] @@ -71,6 +84,13 @@ else: from .components_manager import ComponentsManager from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline + from .flux2 import ( + Flux2AutoBlocks, + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, + Flux2ModularPipeline, + ) from .modular_pipeline import ( AutoPipelineBlocks, BlockState, @@ -87,10 +107,13 @@ QwenImageEditModularPipeline, QwenImageEditPlusAutoBlocks, QwenImageEditPlusModularPipeline, + QwenImageLayeredAutoBlocks, + QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline + from .z_image import ZImageAutoBlocks, ZImageModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index cb7e8fb73697..4a7ea8502c86 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -160,7 +160,10 @@ def __call__(self, hooks, model_id, model, execution_device): if len(hooks) == 0: return [] - current_module_size = model.get_memory_footprint() + try: + current_module_size = model.get_memory_footprint() + except AttributeError: + raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}.") device_type = execution_device.type device_module = getattr(torch, device_type, torch.cuda) @@ -321,6 +324,7 @@ class ComponentsManager: "has_hook", "execution_device", "ip_adapter", + "quantization", ] def __init__(self): @@ -353,7 +357,9 @@ def _lookup_ids( 
ids_by_name.add(component_id) else: ids_by_name = set(components.keys()) - if collection: + if collection and collection not in self.collections: + return set() + elif collection and collection in self.collections: ids_by_collection = set() for component_id, component in components.items(): if component_id in self.collections[collection]: @@ -420,7 +426,8 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): # add component to components manager self.components[component_id] = component - self.added_time[component_id] = time.time() + if is_new_component: + self.added_time[component_id] = time.time() if collection: if collection not in self.collections: @@ -703,7 +710,20 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, if not is_accelerate_available(): raise ImportError("Make sure to install accelerate to use auto_cpu_offload") - # TODO: add a warning if mem_get_info isn't available on `device`. + if device is None: + device = get_device() + if not isinstance(device, torch.device): + device = torch.device(device) + + device_type = device.type + device_module = getattr(torch, device_type, torch.cuda) + if not hasattr(device_module, "mem_get_info"): + raise NotImplementedError( + f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}." + ) + + if device.index is None: + device = torch.device(f"{device.type}:{0}") for name, component in self.components.items(): if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"): @@ -711,11 +731,7 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, self.disable_auto_cpu_offload() offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin) - if device is None: - device = get_device() - device = torch.device(device) - if device.index is None: - device = torch.device(f"{device.type}:{0}") + all_hooks = [] for name, component in self.components.items(): if isinstance(component, torch.nn.Module): @@ -748,7 +764,6 @@ def disable_auto_cpu_offload(self): self.model_hooks = None self._auto_offload_enabled = False - # YiYi TODO: (1) add quantization info def get_model_info( self, component_id: str, @@ -824,6 +839,17 @@ def get_model_info( if scales: info["ip_adapter"] = summarize_dict_by_value_and_parts(scales) + # Check for quantization + hf_quantizer = getattr(component, "hf_quantizer", None) + if hf_quantizer is not None: + quant_config = hf_quantizer.quantization_config + if hasattr(quant_config, "to_diff_dict"): + info["quantization"] = quant_config.to_diff_dict() + else: + info["quantization"] = quant_config.to_dict() + else: + info["quantization"] = None + # If fields specified, filter info if fields is not None: return {k: v for k, v in info.items() if k in fields} @@ -954,12 +980,16 @@ def format_device(component, info): output += "\nAdditional Component Info:\n" + "=" * 50 + "\n" for name in self.components: info = self.get_model_info(name) - if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")): + if info is not None and ( + info.get("adapters") is not None or info.get("ip_adapter") or info.get("quantization") + ): output += f"\n{name}:\n" if info.get("adapters") is not None: output += f" Adapters: {info['adapters']}\n" if info.get("ip_adapter"): output += " IP-Adapter: Enabled\n" + if info.get("quantization"): + output += f" Quantization: {info['quantization']}\n" return output diff --git 
a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index 8309eebfeb37..45b1c6bc136f 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -121,7 +121,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip return components, state -# Adapted from `QwenImageInputsDynamicStep` +# Adapted from `QwenImageAdditionalInputsStep` class FluxInputsDynamicStep(ModularPipelineBlocks): model_name = "flux" diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py index a80bc2a5f7a9..bd9b2d1b40c9 100644 --- a/src/diffusers/modular_pipelines/flux/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -360,7 +360,7 @@ def description(self): AUTO_BLOCKS = InsertableDict( [ ("text_encoder", FluxTextEncoderStep()), - ("image_encoder", FluxAutoVaeEncoderStep()), + ("vae_encoder", FluxAutoVaeEncoderStep()), ("denoise", FluxCoreDenoiseStep()), ("decode", FluxDecodeStep()), ] @@ -369,7 +369,7 @@ def description(self): AUTO_BLOCKS_KONTEXT = InsertableDict( [ ("text_encoder", FluxTextEncoderStep()), - ("image_encoder", FluxKontextAutoVaeEncoderStep()), + ("vae_encoder", FluxKontextAutoVaeEncoderStep()), ("denoise", FluxKontextCoreDenoiseStep()), ("decode", FluxDecodeStep()), ] diff --git a/src/diffusers/modular_pipelines/flux2/__init__.py b/src/diffusers/modular_pipelines/flux2/__init__.py new file mode 100644 index 000000000000..220ec0c4ab65 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/__init__.py @@ -0,0 +1,116 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = [ + "Flux2TextEncoderStep", + "Flux2RemoteTextEncoderStep", + "Flux2VaeEncoderStep", + ] + _import_structure["before_denoise"] = [ + "Flux2SetTimestepsStep", + "Flux2PrepareLatentsStep", + "Flux2RoPEInputsStep", + "Flux2PrepareImageLatentsStep", + ] + _import_structure["denoise"] = [ + "Flux2LoopDenoiser", + "Flux2LoopAfterDenoiser", + "Flux2DenoiseLoopWrapper", + "Flux2DenoiseStep", + ] + _import_structure["decoders"] = ["Flux2DecodeStep"] + _import_structure["inputs"] = [ + "Flux2ProcessImagesInputStep", + "Flux2TextInputStep", + ] + _import_structure["modular_blocks_flux2"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "REMOTE_AUTO_BLOCKS", + "TEXT2IMAGE_BLOCKS", + "IMAGE_CONDITIONED_BLOCKS", + "Flux2AutoBlocks", + "Flux2AutoVaeEncoderStep", + "Flux2CoreDenoiseStep", + "Flux2VaeEncoderSequentialStep", + ] + _import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"] + _import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import 
* # noqa F403 + else: + from .before_denoise import ( + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, + ) + from .decoders import Flux2DecodeStep + from .denoise import ( + Flux2DenoiseLoopWrapper, + Flux2DenoiseStep, + Flux2LoopAfterDenoiser, + Flux2LoopDenoiser, + ) + from .encoders import ( + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, + Flux2VaeEncoderStep, + ) + from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, + ) + from .modular_blocks_flux2 import ( + ALL_BLOCKS, + AUTO_BLOCKS, + IMAGE_CONDITIONED_BLOCKS, + REMOTE_AUTO_BLOCKS, + TEXT2IMAGE_BLOCKS, + Flux2AutoBlocks, + Flux2AutoVaeEncoderStep, + Flux2CoreDenoiseStep, + Flux2VaeEncoderSequentialStep, + ) + from .modular_blocks_flux2_klein import ( + Flux2KleinAutoBlocks, + Flux2KleinBaseAutoBlocks, + ) + from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/flux2/before_denoise.py b/src/diffusers/modular_pipelines/flux2/before_denoise.py new file mode 100644 index 000000000000..d5bab16586d7 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/before_denoise.py @@ -0,0 +1,592 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import numpy as np +import torch + +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + """Compute empirical mu for Flux2 timestep scheduling.""" + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. 
Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Flux2SetTimestepsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("latents", type_hint=torch.Tensor), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + scheduler = components.scheduler + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + vae_scale_factor = components.vae_scale_factor + + latent_height = 2 * (int(height) // (vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + num_inference_steps = block_state.num_inference_steps + sigmas = block_state.sigmas + timesteps = block_state.timesteps + + if timesteps is None and sigmas is None: + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas: + sigmas = None + + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps, + device, + timesteps=timesteps, + sigmas=sigmas, + mu=mu, + ) + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + 
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ), + OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"), + ] + + @staticmethod + def check_inputs(components, block_state): + vae_scale_factor = components.vae_scale_factor + if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or ( + block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + def _prepare_latent_ids(latents: torch.Tensor): + """ + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents: Latent tensor of shape (B, C, H, W) + + Returns: + Position IDs tensor of shape (B, H*W, 4) + """ + batch_size, _, height, width = latents.shape + + t = torch.arange(1) + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) + + latent_ids = torch.cartesian_prod(t, h, w, l) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @staticmethod + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.device = components._execution_device + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + latents = self.prepare_latents( + components, + batch_size, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(block_state.device) + + latents = self._pack_latents(latents) + + block_state.latents = latents + block_state.latent_ids = latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps." 
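# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The RoPE-input steps above build 4D (T, H, W, L) position IDs for text tokens
# with `torch.cartesian_prod`: T, H and W are pinned to a single coordinate and
# only the sequence axis L varies. A minimal standalone reproduction of that
# helper, assuming a dummy prompt-embedding tensor with made-up shapes:

import torch

prompt_embeds = torch.randn(2, 5, 8)  # (B, L, D), toy values
B, L, _ = prompt_embeds.shape
txt_ids = torch.stack(
    [
        torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(L))
        for _ in range(B)
    ]
)
# txt_ids.shape == (2, 5, 4); each row is (0, 0, 0, l) for l in range(L)
# -------------------------------------------------------------------------------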
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="prompt_embeds", required=True), + InputParam(name="negative_prompt_embeds", required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.", + ), + OutputParam( + name="negative_txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.", + ), + ] + + @staticmethod + def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None): + """Prepare 4D position IDs for text tokens.""" + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + seq_l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, seq_l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device = prompt_embeds.device + + block_state.txt_ids = self._prepare_text_ids(prompt_embeds) + block_state.txt_ids = block_state.txt_ids.to(device) + + block_state.negative_txt_ids = None + if block_state.negative_prompt_embeds is not None: + block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds) + block_state.negative_txt_ids = block_state.negative_txt_ids.to(device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareImageLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares image latents and their position IDs for Flux2 image conditioning." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image_latents", type_hint=List[torch.Tensor]), + InputParam("batch_size", required=True, type_hint=int), + InputParam("num_images_per_prompt", default=1, type_hint=int), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning", + ), + OutputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents", + ), + ] + + @staticmethod + def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10): + """ + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + Args: + image_latents: A list of image latent feature tensors of shape (1, C, H, W). + scale: Factor used to define the time separation between latents. 
+ + Returns: + Combined coordinate tensor of shape (1, N_total, 4) + """ + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _pack_latents(latents): + """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)""" + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + return latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image_latents = block_state.image_latents + + if image_latents is None: + block_state.image_latents = None + block_state.image_latent_ids = None + self.set_block_state(state, block_state) + + return components, state + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + image_latent_ids = self._prepare_image_ids(image_latents) + + packed_latents = [] + for latent in image_latents: + packed = self._pack_latents(latent) + packed = packed.squeeze(0) + packed_latents.append(packed) + + image_latents = torch.cat(packed_latents, dim=0) + image_latents = image_latents.unsqueeze(0) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + block_state.image_latents = image_latents + block_state.image_latent_ids = image_latent_ids + + self.set_block_state(state, block_state) + return components, state + + +class Flux2PrepareGuidanceStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that prepares the guidance scale tensor for Flux2 inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("guidance_scale", default=4.0), + InputParam("num_images_per_prompt", default=1), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(batch_size) + block_state.guidance = guidance + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/decoders.py b/src/diffusers/modular_pipelines/flux2/decoders.py new file mode 
100644 index 000000000000..c79375072037 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/decoders.py @@ -0,0 +1,183 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2UnpackLatentsStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Step that unpacks the latents from the denoising step" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="Position IDs for the latents, used for unpacking", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The denoise latents from denoising step, unpacked with position IDs.", + ) + ] + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor: + """ + Unpack latents using position IDs to scatter tokens into place. 
+ + Args: + x: Packed latents tensor of shape (B, seq_len, C) + x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates + + Returns: + Unpacked latents tensor of shape (B, C, H, W) + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + latent_ids = block_state.latent_ids + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + block_state.latents = latents + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DecodeStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @staticmethod + def _unpatchify_latents(latents): + """Convert patchified latents back to regular format.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + latents = block_state.latents + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + + latents = self._unpatchify_latents(latents) + + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/denoise.py b/src/diffusers/modular_pipelines/flux2/denoise.py new file mode 100644 index 000000000000..a726959a29e2 --- /dev/null 
+++ b/src/diffusers/modular_pipelines/flux2/denoise.py @@ -0,0 +1,509 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2LoopDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "guidance", + required=True, + type_hint=torch.Tensor, + description="Guidance scale as a tensor", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Mistral3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=block_state.guidance, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# same as Flux2LoopDenoiser but guidance=None +class Flux2KleinLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("transformer", Flux2Transformer2DModel)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=block_state.prompt_embeds, + txt_ids=block_state.txt_ids, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1)] + block_state.noise_pred = noise_pred + + return components, block_state + + +# support CFG for Flux2-Klein base model +class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", Flux2Transformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents for Flux2. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("joint_attention_kwargs"), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise. Shape: (B, seq_len, C)", + ), + InputParam( + "image_latents", + type_hint=torch.Tensor, + description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)", + ), + InputParam( + "image_latent_ids", + type_hint=torch.Tensor, + description="Position IDs for image latents. 
Shape: (B, img_seq_len, 4)", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings from Qwen3", + ), + InputParam( + "negative_prompt_embeds", + required=False, + type_hint=torch.Tensor, + description="Negative text embeddings from Qwen3", + ), + InputParam( + "txt_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for text tokens (T, H, W, L)", + ), + InputParam( + "negative_txt_ids", + required=False, + type_hint=torch.Tensor, + description="4D position IDs for negative text tokens (T, H, W, L)", + ), + InputParam( + "latent_ids", + required=True, + type_hint=torch.Tensor, + description="4D position IDs for latent tokens (T, H, W, L)", + ), + ] + + @torch.no_grad() + def __call__( + self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latents = block_state.latents + latent_model_input = latents.to(components.transformer.dtype) + img_ids = block_state.latent_ids + + image_latents = getattr(block_state, "image_latents", None) + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype) + image_latent_ids = block_state.image_latent_ids + img_ids = torch.cat([img_ids, image_latent_ids], dim=1) + + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "txt_ids": ( + getattr(block_state, "txt_ids", None), + getattr(block_state, "negative_txt_ids", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + img_ids=img_ids, + joint_attention_kwargs=block_state.joint_attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)] + components.guider.cleanup_models(components.transformer) + + # perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class Flux2LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "flux2" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents after denoising. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. 
`Flux2DenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> List[str]: + return [InputParam("generator")] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. " + "The specific steps within each iteration can be customized with `sub_blocks` attribute" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", Flux2Transformer2DModel), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process.", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self.set_block_state(state, block_state) + return components, state + + +class Flux2DenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2LoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. 
\n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) + + +class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper): + block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents for Flux2. \n" + "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `Flux2KleinBaseLoopDenoiser`\n" + " - `Flux2LoopAfterDenoiser`\n" + "This block supports both text-to-image and image-conditioned generation." + ) diff --git a/src/diffusers/modular_pipelines/flux2/encoders.py b/src/diffusers/modular_pipelines/flux2/encoders.py new file mode 100644 index 000000000000..265fb387367c --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/encoders.py @@ -0,0 +1,609 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Tuple, Union + +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLFlux2 +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def format_text_input(prompts: List[str], system_message: str = None): + """Format prompts for Mistral3 chat template.""" + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2TextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + # fmt: off + DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Mistral3ForConditionalGeneration), + ComponentSpec("tokenizer", AutoProcessor), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from Mistral3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + def _get_mistral_3_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: Tuple[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_mistral_3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=block_state.device, + max_sequence_length=block_state.max_sequence_length, + system_message=self.DEFAULT_SYSTEM_MESSAGE, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2RemoteTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using a remote API endpoint" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from remote API used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + import io + + import requests + from huggingface_hub import get_token + + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + response = requests.post( + self.REMOTE_URL, + 
json={"prompt": prompt}, + headers={ + "Authorization": f"Bearer {get_token()}", + "Content-Type": "application/json", + }, + ) + response.raise_for_status() + + block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True) + block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=True), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + 
@torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen3ForCausalLM), + ComponentSpec("tokenizer", Qwen2TokenizerFast), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec(name="is_distilled", default=False), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("max_sequence_length", type_hint=int, default=512, required=False), + InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Text embeddings from qwen3 used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Negative text embeddings from qwen3 used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + prompt = block_state.prompt + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + 
input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @torch.no_grad() + def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = "" + prompt = [prompt] if isinstance(prompt, str) else prompt + + block_state.prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + + if components.requires_unconditional_embeds: + negative_prompt = [""] * len(prompt) + block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + device=device, + max_sequence_length=block_state.max_sequence_length, + hidden_states_layers=block_state.text_encoder_out_layers, + ) + else: + block_state.negative_prompt_embeds = None + + self.set_block_state(state, block_state) + return components, state + + +class Flux2VaeEncoderStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2." 
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKLFlux2)] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("condition_images", type_hint=List[torch.Tensor]), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=List[torch.Tensor], + description="List of latent representations for each reference image", + ), + ] + + @staticmethod + def _patchify_latents(latents): + """Convert latents to patchified format for Flux2.""" + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator): + """Encode a single image using Flux2 VAE with batch norm normalization.""" + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps) + latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + condition_images = block_state.condition_images + + if condition_images is None: + return components, state + + device = components._execution_device + dtype = components.vae.dtype + + image_latents = [] + for image in condition_images: + image = image.to(device=device, dtype=dtype) + latent = self._encode_vae_image( + vae=components.vae, + image=image, + generator=block_state.generator, + ) + image_latents.append(latent) + + block_state.image_latents = image_latents + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/inputs.py b/src/diffusers/modular_pipelines/flux2/inputs.py new file mode 100644 index 000000000000..3463de1999c6 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/inputs.py @@ -0,0 +1,244 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
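Aside on `Flux2VaeEncoderStep._patchify_latents` above: it folds every 2x2 spatial patch into the channel dimension before the batch-norm statistics are applied. A shape-only sketch with a dummy tensor (the sizes are arbitrary):

```python
import torch

def patchify_latents(latents: torch.Tensor) -> torch.Tensor:
    # Same reshuffle as _patchify_latents: (B, C, H, W) -> (B, 4C, H/2, W/2)
    batch_size, num_channels, height, width = latents.shape
    latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4)
    return latents.reshape(batch_size, num_channels * 4, height // 2, width // 2)

dummy = torch.randn(1, 32, 128, 128)      # made-up VAE latents
print(patchify_latents(dummy).shape)      # torch.Size([1, 128, 64, 64])
```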
+ +from typing import List + +import torch + +from ...configuration_utils import FrozenDict +from ...pipelines.flux2.image_processor import Flux2ImageProcessor +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import Flux2ModularPipeline + + +logger = logging.get_logger(__name__) + + +class Flux2TextInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2KleinBaseTextInputStep(ModularPipelineBlocks): + model_name = "flux2-klein" + + @property + def description(self) -> str: + return ( + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + required=False, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. 
Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Negative text embeddings used to guide the image generation", + ), + ] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + return components, state + + +class Flux2ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "flux2" + + @property + def description(self) -> str: + return "Image preprocess step for Flux2. Validates and preprocesses reference images." 
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + Flux2ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image"), + InputParam("height"), + InputParam("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])] + + @torch.no_grad() + def __call__(self, components: Flux2ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + images = block_state.image + + if images is None: + block_state.condition_images = None + self.set_block_state(state, block_state) + return components, state + + if not isinstance(images, list): + images = [images] + + condition_images = [] + for img in images: + components.image_processor.check_image_input(img) + + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = components.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = components.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + condition_img = components.image_processor.preprocess( + img, height=image_height, width=image_width, resize_mode="crop" + ) + condition_images.append(condition_img) + + if block_state.height is None: + block_state.height = image_height + if block_state.width is None: + block_state.width = image_width + + block_state.condition_images = condition_images + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py new file mode 100644 index 000000000000..41a0ff7dee28 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2.py @@ -0,0 +1,206 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
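Aside on the sizing logic in `Flux2ProcessImagesInputStep` above: reference images are capped at roughly one megapixel and their sides snapped down to a multiple of `vae_scale_factor * 2`. A standalone sketch of the snapping arithmetic (the 1000x765 input size and the scale factor of 8 are assumptions for illustration; the step itself reads `components.vae_scale_factor`):

```python
vae_scale_factor = 8                 # assumed value for this example
multiple_of = vae_scale_factor * 2   # 16

image_width, image_height = 1000, 765  # example size; 765,000 px <= 1024 * 1024, so no area resize here
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of

print(image_width, image_height)  # 992 752
```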
+ +from typing import List + +import PIL.Image +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2PrepareGuidanceStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2DenoiseStep +from .encoders import ( + Flux2RemoteTextEncoderStep, + Flux2TextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +Flux2VaeEncoderBlocks = InsertableDict( + [ + ("preprocess", Flux2ProcessImagesInputStep()), + ("encode", Flux2VaeEncoderStep()), + ] +) + + +class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2VaeEncoderBlocks.values() + block_names = Flux2VaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning." + + +class Flux2AutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2VaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." + ) + + +Flux2CoreDenoiseBlocks = InsertableDict( + [ + ("input", Flux2TextInputStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ] +) + + +class Flux2CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2CoreDenoiseBlocks.values() + block_names = Flux2CoreDenoiseBlocks.keys() + + @property + def description(self): + return ( + "Core denoise step that performs the denoising process for Flux2-dev.\n" + " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n" + " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" + " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" + " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" + " - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n" + " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n" + " - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] + + 
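Since `Flux2CoreDenoiseStep` is an ordinary `SequentialPipelineBlocks` assembled from `Flux2CoreDenoiseBlocks`, it can be instantiated and inspected on its own. A rough sketch, assuming this PR's module layout is installed and that `sub_blocks` exposes the name-keyed mapping the descriptions above refer to:

```python
# Assumes the module added in this PR is importable.
from diffusers.modular_pipelines.flux2.modular_blocks_flux2 import Flux2CoreDenoiseStep

denoise_step = Flux2CoreDenoiseStep()
print(denoise_step.block_names)
# expected order: input, prepare_image_latents, prepare_latents, set_timesteps,
# prepare_guidance, prepare_rope_inputs, denoise, after_denoise

# The per-iteration logic lives in the nested loop block; assumption: sub_blocks is keyed by block name.
print(denoise_step.sub_blocks["denoise"].description)
```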
+AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +REMOTE_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2RemoteTextEncoderStep()), + ("vae_encoder", Flux2AutoVaeEncoderStep()), + ("denoise", Flux2CoreDenoiseStep()), + ("decode", Flux2DecodeStep()), + ] +) + + +class Flux2AutoBlocks(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n" + "- For text-to-image generation, all you need to provide is `prompt`.\n" + "- For image-conditioned generation, you need to provide `image` (list of PIL images)." + ) + + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + + +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ("decode", Flux2DecodeStep()), + ] +) + +IMAGE_CONDITIONED_BLOCKS = InsertableDict( + [ + ("text_encoder", Flux2TextEncoderStep()), + ("text_input", Flux2TextInputStep()), + ("preprocess_images", Flux2ProcessImagesInputStep()), + ("vae_encoder", Flux2VaeEncoderStep()), + ("prepare_image_latents", Flux2PrepareImageLatentsStep()), + ("prepare_latents", Flux2PrepareLatentsStep()), + ("set_timesteps", Flux2SetTimestepsStep()), + ("prepare_guidance", Flux2PrepareGuidanceStep()), + ("prepare_rope_inputs", Flux2RoPEInputsStep()), + ("denoise", Flux2DenoiseStep()), + ("after_denoise", Flux2UnpackLatentsStep()), + ("decode", Flux2DecodeStep()), + ] +) + +ALL_BLOCKS = { + "text2image": TEXT2IMAGE_BLOCKS, + "image_conditioned": IMAGE_CONDITIONED_BLOCKS, + "auto": AUTO_BLOCKS, + "remote": REMOTE_AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py new file mode 100644 index 000000000000..984832d77be5 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_blocks_flux2_klein.py @@ -0,0 +1,232 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
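Back to the `ALL_BLOCKS` presets that close `modular_blocks_flux2.py` above: they are meant to be composed into a single sequential block before being turned into a pipeline. A hedged sketch, assuming `SequentialPipelineBlocks.from_blocks_dict` (used with preset dicts elsewhere in diffusers' modular pipelines) also accepts these instantiated presets:

```python
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.flux2.modular_blocks_flux2 import ALL_BLOCKS

# Pick the text-to-image preset and compose it into one sequential block.
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2image"])
print(t2i_blocks.block_names)  # expected order: text_encoder, text_input, ..., after_denoise, decode
```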
+ +from typing import List + +import PIL.Image +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + Flux2KleinBaseRoPEInputsStep, + Flux2PrepareImageLatentsStep, + Flux2PrepareLatentsStep, + Flux2RoPEInputsStep, + Flux2SetTimestepsStep, +) +from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep +from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep +from .encoders import ( + Flux2KleinBaseTextEncoderStep, + Flux2KleinTextEncoderStep, + Flux2VaeEncoderStep, +) +from .inputs import ( + Flux2KleinBaseTextInputStep, + Flux2ProcessImagesInputStep, + Flux2TextInputStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +################ +# VAE encoder +################ + +Flux2KleinVaeEncoderBlocks = InsertableDict( + [ + ("preprocess", Flux2ProcessImagesInputStep()), + ("encode", Flux2VaeEncoderStep()), + ] +) + + +class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks): + model_name = "flux2" + + block_classes = Flux2KleinVaeEncoderBlocks.values() + block_names = Flux2KleinVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations." + + +class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [Flux2KleinVaeEncoderSequentialStep] + block_names = ["img_conditioning"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "VAE encoder step that encodes the image inputs into their latent representations.\n" + "This is an auto pipeline block that works for image conditioning tasks.\n" + " - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n" + " - If `image` is not provided, step will be skipped." 
+        )
+
+
+###
+### Core denoise
+###
+
+Flux2KleinCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2TextInputStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
+        ("denoise", Flux2KleinDenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
+    ]
+)
+
+
+class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+
+    block_classes = Flux2KleinCoreDenoiseBlocks.values()
+    block_names = Flux2KleinCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+        return (
+            "Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
+            " - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
+            " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
+            " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
+            " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
+            " - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
+            " - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
+            " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
+        )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam(
+                name="latents",
+                type_hint=torch.Tensor,
+                description="The latents from the denoising step.",
+            )
+        ]
+
+
+Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
+    [
+        ("input", Flux2KleinBaseTextInputStep()),
+        ("prepare_latents", Flux2PrepareLatentsStep()),
+        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
+        ("set_timesteps", Flux2SetTimestepsStep()),
+        ("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
+        ("denoise", Flux2KleinBaseDenoiseStep()),
+        ("after_denoise", Flux2UnpackLatentsStep()),
+    ]
+)
+
+
+class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux2-klein"
+    block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
+    block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+ return ( + "Core denoise step that performs the denoising process for Flux2-Klein (base model).\n" + " - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n" + " - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n" + " - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n" + " - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n" + " - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n" + " - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n" + " - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The latents from the denoising step.", + ) + ] + + +### +### Auto blocks +### +class Flux2KleinAutoBlocks(SequentialPipelineBlocks): + model_name = "flux2-klein" + block_classes = [ + Flux2KleinTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + + @property + def description(self): + return ( + "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n" + + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n" + + " - for text-to-image generation, all you need to provide is `prompt`.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] + + +class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks): + model_name = "flux2-klein" + block_classes = [ + Flux2KleinBaseTextEncoderStep(), + Flux2KleinAutoVaeEncoderStep(), + Flux2KleinBaseCoreDenoiseStep(), + Flux2DecodeStep(), + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + + @property + def description(self): + return ( + "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n" + + " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n" + + " - for text-to-image generation, all you need to provide is `prompt`.\n" + ) + + @property + def outputs(self): + return [ + OutputParam( + name="images", + type_hint=List[PIL.Image.Image], + description="The images from the decoding step.", + ) + ] diff --git a/src/diffusers/modular_pipelines/flux2/modular_pipeline.py b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py new file mode 100644 index 000000000000..29fbeba07c24 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux2/modular_pipeline.py @@ -0,0 +1,112 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional + +from ...loaders import Flux2LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "Flux2AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + +class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin): + """ + A ModularPipeline for Flux2-Klein. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "Flux2KleinBaseAutoBlocks" + + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]: + return "Flux2KleinAutoBlocks" + else: + return "Flux2KleinBaseAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 32 + if getattr(self, "transformer", None): + num_channels_latents = self.transformer.config.in_channels // 4 + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + if hasattr(self.config, "is_distilled") and self.config.is_distilled: + return False + + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py index a405aebee221..35241023f3fc 100644 --- a/src/diffusers/modular_pipelines/mellon_node_utils.py +++ b/src/diffusers/modular_pipelines/mellon_node_utils.py @@ -1,447 +1,862 @@ +import copy import json import logging import os # Simple typed wrapper for parameter overrides from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union -from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub import create_repo, hf_hub_download, upload_file from huggingface_hub.utils import ( EntryNotFoundError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError, - validate_hf_hub_args, ) -from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash -from .modular_pipeline import ModularPipelineBlocks +from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT +from .modular_pipeline_utils import InputParam, OutputParam logger = logging.getLogger(__name__) -SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"} +def _name_to_label(name: str) -> str: + """Convert snake_case name to Title Case label.""" + return name.replace("_", " ").title() -# Mellon Input Parameters (runtime parameters, not models) -MELLON_INPUT_PARAMS = { - # controlnet +# Template definitions for standard diffuser pipeline parameters +MELLON_PARAM_TEMPLATES = { + # Image I/O + "image": {"label": "Image", "type": "image", "display": "input", "required_block_params": ["image"]}, + "images": {"label": "Images", "type": "image", "display": "output", "required_block_params": ["images"]}, "control_image": { "label": "Control Image", "type": "image", "display": "input", + "required_block_params": ["control_image"], }, - "controlnet_conditioning_scale": { - "label": "Scale", - "type": "float", - "default": 0.5, - "min": 0, - "max": 1, - }, - "control_guidance_end": { - "label": "End", - "type": "float", - "default": 1.0, - "min": 0, - "max": 1, - }, - "control_guidance_start": { - "label": 
"Start", - "type": "float", - "default": 0.0, - "min": 0, - "max": 1, - }, - "controlnet": { - "label": "Controlnet", - "type": "custom_controlnet", + # Latents + "latents": {"label": "Latents", "type": "latents", "display": "input", "required_block_params": ["latents"]}, + "image_latents": { + "label": "Image Latents", + "type": "latents", "display": "input", + "required_block_params": ["image_latents"], }, - "embeddings": { - "label": "Text Embeddings", + "first_frame_latents": { + "label": "First Frame Latents", + "type": "latents", "display": "input", - "type": "embeddings", + "required_block_params": ["first_frame_latents"], }, - "image": { - "label": "Image", - "type": "image", + "latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"}, + # Image Latents with Strength + "image_latents_with_strength": { + "name": "image_latents", # name is not same as template key + "label": "Image Latents", + "type": "latents", "display": "input", + "onChange": {"false": ["height", "width"], "true": ["strength"]}, + "required_block_params": ["image_latents", "strength"], }, - "negative_prompt": { - "label": "Negative Prompt", - "type": "string", - "default": "", - "display": "textarea", + # Embeddings + "embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"}, + "image_embeds": { + "label": "Image Embeddings", + "type": "image_embeds", + "display": "output", + "required_block_params": ["image_embeds"], }, + # Text inputs "prompt": { "label": "Prompt", "type": "string", + "display": "textarea", "default": "", + "required_block_params": ["prompt"], + }, + "negative_prompt": { + "label": "Negative Prompt", + "type": "string", "display": "textarea", + "default": "", + "required_block_params": ["negative_prompt"], }, + # Numeric params "guidance_scale": { "label": "Guidance Scale", "type": "float", "display": "slider", - "default": 5, + "default": 5.0, "min": 1.0, "max": 30.0, "step": 0.1, }, + "strength": { + "label": "Strength", + "type": "float", + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["strength"], + }, "height": { "label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8, + "required_block_params": ["height"], }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "input", - "onChange": {False: ["height", "width"], True: ["strength"]}, + "width": { + "label": "Width", + "type": "int", + "default": 1024, + "min": 64, + "step": 8, + "required_block_params": ["width"], }, - "latents": { - "label": "Latents", - "type": "latents", - "display": "input", + "seed": { + "label": "Seed", + "type": "int", + "default": 0, + "min": 0, + "max": 4294967295, + "display": "random", + "required_block_params": ["generator"], }, "num_inference_steps": { "label": "Steps", "type": "int", - "display": "slider", "default": 25, "min": 1, "max": 100, + "display": "slider", + "required_block_params": ["num_inference_steps"], }, - "seed": { - "label": "Seed", + "num_frames": { + "label": "Frames", "type": "int", - "display": "random", - "default": 0, - "min": 0, - "max": 4294967295, + "default": 81, + "min": 1, + "max": 480, + "display": "slider", + "required_block_params": ["num_frames"], }, - "strength": { - "label": "Strength", + "layers": { + "label": "Layers", + "type": "int", + "default": 4, + "min": 1, + "max": 10, + "display": "slider", + "required_block_params": ["layers"], + }, + # ControlNet + "controlnet_conditioning_scale": { + "label": "Controlnet Conditioning 
Scale", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, + "required_block_params": ["controlnet_conditioning_scale"], }, - "width": { - "label": "Width", - "type": "int", - "default": 1024, - "min": 64, - "step": 8, + "control_guidance_start": { + "label": "Control Guidance Start", + "type": "float", + "default": 0.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["control_guidance_start"], }, - "ip_adapter": { - "label": "IP Adapter", - "type": "custom_ip_adapter", - "display": "input", + "control_guidance_end": { + "label": "Control Guidance End", + "type": "float", + "default": 1.0, + "min": 0.0, + "max": 1.0, + "step": 0.01, + "required_block_params": ["control_guidance_end"], }, -} - -# Mellon Model Parameters (diffusers_auto_model types) -MELLON_MODEL_PARAMS = { - "scheduler": { - "label": "Scheduler", + # Video + "videos": {"label": "Videos", "type": "video", "display": "output", "required_block_params": ["videos"]}, + # Models + "vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input", "required_block_params": ["vae"]}, + "image_encoder": { + "label": "Image Encoder", + "type": "diffusers_auto_model", "display": "input", + "required_block_params": ["image_encoder"], + }, + "unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"}, + "scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"}, + "controlnet": { + "label": "ControlNet Model", "type": "diffusers_auto_model", + "display": "input", + "required_block_params": ["controlnet"], }, "text_encoders": { "label": "Text Encoders", "type": "diffusers_auto_models", "display": "input", + "required_block_params": ["text_encoder"], }, - "unet": { - "label": "Unet", + # Bundles/Custom + "controlnet_bundle": { + "label": "ControlNet", + "type": "custom_controlnet", "display": "input", - "type": "diffusers_auto_model", - "onSignal": { - "action": "signal", - "target": "guider", - }, + "required_block_params": "controlnet_image", }, + "ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"}, "guider": { "label": "Guider", - "display": "input", "type": "custom_guider", - "onChange": {False: ["guidance_scale"], True: []}, - }, - "vae": { - "label": "VAE", - "display": "input", - "type": "diffusers_auto_model", - }, - "controlnet": { - "label": "Controlnet Model", - "type": "diffusers_auto_model", "display": "input", + "onChange": {False: ["guidance_scale"], True: []}, }, + "doc": {"label": "Doc", "type": "string", "display": "output"}, } -# Mellon Output Parameters (display = "output") -MELLON_OUTPUT_PARAMS = { - "embeddings": { - "label": "Text Embeddings", - "display": "output", - "type": "embeddings", - }, - "images": { - "label": "Images", - "type": "image", - "display": "output", - }, - "image_latents": { - "label": "Image Latents", - "type": "latents", - "display": "output", - }, - "latents": { - "label": "Latents", - "type": "latents", - "display": "output", - }, - "latents_preview": { - "label": "Latents Preview", - "display": "output", - "type": "latent", - }, - "controlnet_out": { - "label": "Controlnet", - "display": "output", - "type": "controlnet", - }, -} +class MellonParamMeta(type): + """Metaclass that enables MellonParam.template_name(**overrides) syntax.""" -# Default param selections per supported node_type -# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. 
-NODE_TYPE_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - "vae", - ], - "outputs": [ - "controlnet", - ], - "block_names": ["controlnet_vae_encoder"], - }, + def __getattr__(cls, name: str): + if name in MELLON_PARAM_TEMPLATES: + + def factory(default=None, **overrides): + template = MELLON_PARAM_TEMPLATES[name] + # Use template's name if specified, otherwise use the key + params = {"name": template.get("name", name), **template, **overrides} + if default is not None: + params["default"] = default + return cls(**params) + + return factory + + raise AttributeError(f"type object 'MellonParam' has no attribute '{name}'") + + +@dataclass(frozen=True) +class MellonParam(metaclass=MellonParamMeta): + """ + Parameter definition for Mellon nodes. + + Usage: + ```python + # From template (standard diffuser params) + MellonParam.seed() + MellonParam.prompt(default="a cat") + MellonParam.latents(display="output") + + # Generic inputs (for custom blocks) + MellonParam.Input.slider("my_scale", default=1.0, min=0.0, max=2.0) + MellonParam.Input.dropdown("mode", options=["fast", "slow"]) + + # Generic outputs + MellonParam.Output.image("result_images") + + # Fully custom + MellonParam(name="custom", label="Custom", type="float", default=0.5) + ``` + """ + + name: str + label: str + type: str + display: Optional[str] = None + default: Any = None + min: Optional[float] = None + max: Optional[float] = None + step: Optional[float] = None + options: Any = None + value: Any = None + fieldOptions: Optional[Dict[str, Any]] = None + onChange: Any = None + onSignal: Any = None + required_block_params: Optional[Union[str, List[str]]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dict for Mellon schema, excluding None values and internal fields.""" + data = asdict(self) + return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")} + + # ========================================================================= + # Input: Generic input parameter factories (for custom blocks) + # ========================================================================= + class Input: + """input UI elements for custom blocks.""" + + @classmethod + def image(cls, name: str) -> "MellonParam": + """image input.""" + return MellonParam(name=name, label=_name_to_label(name), type="image", display="input") + + @classmethod + def textbox(cls, name: str, default: str = "") -> "MellonParam": + """text input as textarea.""" + return MellonParam( + name=name, label=_name_to_label(name), type="string", display="textarea", default=default + ) + + @classmethod + def dropdown(cls, name: str, options: List[str] = None, default: str = None) -> "MellonParam": + """dropdown selection.""" + if options and not default: + default = options[0] + if not default: + default = "" + if not options: + options = [default] + return MellonParam(name=name, label=_name_to_label(name), type="string", options=options, value=default) + + @classmethod + def slider( + cls, name: str, default: float = 0, min: float = None, max: float = None, step: float = None + ) -> "MellonParam": + """slider input.""" + is_float = isinstance(default, float) or (step is not None and isinstance(step, float)) + param_type = "float" if is_float else "int" + if min is None: + min = default + if max is None: + max = default + if step is None: + step = 0.01 if 
is_float else 1 + return MellonParam( + name=name, + label=_name_to_label(name), + type=param_type, + display="slider", + default=default, + min=min, + max=max, + step=step, + ) + + @classmethod + def number( + cls, name: str, default: float = 0, min: float = None, max: float = None, step: float = None + ) -> "MellonParam": + """number input (no slider).""" + is_float = isinstance(default, float) or (step is not None and isinstance(step, float)) + param_type = "float" if is_float else "int" + return MellonParam( + name=name, label=_name_to_label(name), type=param_type, default=default, min=min, max=max, step=step + ) + + @classmethod + def seed(cls, name: str = "seed", default: int = 0) -> "MellonParam": + """seed input with randomize button.""" + return MellonParam( + name=name, + label=_name_to_label(name), + type="int", + display="random", + default=default, + min=0, + max=4294967295, + ) + + @classmethod + def checkbox(cls, name: str, default: bool = False) -> "MellonParam": + """boolean checkbox.""" + return MellonParam(name=name, label=_name_to_label(name), type="boolean", value=default) + + @classmethod + def custom_type(cls, name: str, type: str) -> "MellonParam": + """custom type input for node connections.""" + return MellonParam(name=name, label=_name_to_label(name), type=type, display="input") + + @classmethod + def model(cls, name: str) -> "MellonParam": + """model input for diffusers components.""" + return MellonParam(name=name, label=_name_to_label(name), type="diffusers_auto_model", display="input") + + # ========================================================================= + # Output: Generic output parameter factories (for custom blocks) + # ========================================================================= + class Output: + """output UI elements for custom blocks.""" + + @classmethod + def image(cls, name: str) -> "MellonParam": + """image output.""" + return MellonParam(name=name, label=_name_to_label(name), type="image", display="output") + + @classmethod + def video(cls, name: str) -> "MellonParam": + """video output.""" + return MellonParam(name=name, label=_name_to_label(name), type="video", display="output") + + @classmethod + def text(cls, name: str) -> "MellonParam": + """text output.""" + return MellonParam(name=name, label=_name_to_label(name), type="string", display="output") + + @classmethod + def custom_type(cls, name: str, type: str) -> "MellonParam": + """custom type output for node connections.""" + return MellonParam(name=name, label=_name_to_label(name), type=type, display="output") + + @classmethod + def model(cls, name: str) -> "MellonParam": + """model output for diffusers components.""" + return MellonParam(name=name, label=_name_to_label(name), type="diffusers_auto_model", display="output") + + +def input_param_to_mellon_param(input_param: "InputParam") -> MellonParam: + """ + Convert an InputParam to a MellonParam using metadata. + + Args: + input_param: An InputParam with optional metadata containing either: + - {"mellon": ""} for simple types (image, textbox, slider, etc.) 
+ - {"mellon": MellonParam(...)} for full control over UI configuration + + Returns: + MellonParam instance + """ + name = input_param.name + metadata = input_param.metadata + mellon_value = metadata.get("mellon") if metadata else None + default = input_param.default + + # If it's already a MellonParam, return it directly + if isinstance(mellon_value, MellonParam): + return mellon_value + + mellon_type = mellon_value + + if mellon_type == "image": + return MellonParam.Input.image(name) + elif mellon_type == "textbox": + return MellonParam.Input.textbox(name, default=default or "") + elif mellon_type == "dropdown": + return MellonParam.Input.dropdown(name, default=default or "") + elif mellon_type == "slider": + return MellonParam.Input.slider(name, default=default or 0) + elif mellon_type == "number": + return MellonParam.Input.number(name, default=default or 0) + elif mellon_type == "seed": + return MellonParam.Input.seed(name, default=default or 0) + elif mellon_type == "checkbox": + return MellonParam.Input.checkbox(name, default=default or False) + elif mellon_type == "model": + return MellonParam.Input.model(name) + else: + # None or unknown -> custom + return MellonParam.Input.custom_type(name, type="custom") + + +def output_param_to_mellon_param(output_param: "OutputParam") -> MellonParam: + """ + Convert an OutputParam to a MellonParam using metadata. + + Args: + output_param: An OutputParam with optional metadata={"mellon": ""} where type is one of: + image, video, text, model. If metadata is None or unknown, maps to "custom". + + Returns: + MellonParam instance + """ + name = output_param.name + metadata = output_param.metadata + mellon_type = metadata.get("mellon") if metadata else None + + if mellon_type == "image": + return MellonParam.Output.image(name) + elif mellon_type == "video": + return MellonParam.Output.video(name) + elif mellon_type == "text": + return MellonParam.Output.text(name) + elif mellon_type == "model": + return MellonParam.Output.model(name) + else: + # None or unknown -> custom + return MellonParam.Output.custom_type(name, type="custom") + + +DEFAULT_NODE_SPECS = { + "controlnet": None, "denoise": { "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - # custom adapters coming in as inputs - "controlnet", - # ip_adapter is optional and custom; include if available - "ip_adapter", + MellonParam.embeddings(display="input"), + MellonParam.width(), + MellonParam.height(), + MellonParam.seed(), + MellonParam.num_inference_steps(), + MellonParam.num_frames(), + MellonParam.guidance_scale(), + MellonParam.strength(), + MellonParam.image_latents_with_strength(), + MellonParam.image_latents(), + MellonParam.first_frame_latents(), + MellonParam.controlnet_bundle(display="input"), ], "model_inputs": [ - "unet", - "guider", - "scheduler", + MellonParam.unet(), + MellonParam.guider(), + MellonParam.scheduler(), ], "outputs": [ - "latents", - "latents_preview", + MellonParam.latents(display="output"), + MellonParam.latents_preview(), + MellonParam.doc(), ], - "block_names": ["denoise"], + "required_inputs": ["embeddings"], + "required_model_inputs": ["unet", "scheduler"], + "block_name": "denoise", }, "vae_encoder": { "inputs": [ - "image", - "width", - "height", + MellonParam.image(), ], "model_inputs": [ - "vae", + MellonParam.vae(), ], "outputs": [ - "image_latents", + MellonParam.image_latents(display="output"), + MellonParam.doc(), ], - "block_names": ["vae_encoder"], + 
"required_inputs": ["image"], + "required_model_inputs": ["vae"], + "block_name": "vae_encoder", }, "text_encoder": { "inputs": [ - "prompt", - "negative_prompt", - # optional image prompt input supported in embeddings node - "image", + MellonParam.prompt(), + MellonParam.negative_prompt(), ], "model_inputs": [ - "text_encoders", + MellonParam.text_encoders(), ], "outputs": [ - "embeddings", + MellonParam.embeddings(display="output"), + MellonParam.doc(), ], - "block_names": ["text_encoder"], + "required_inputs": ["prompt"], + "required_model_inputs": ["text_encoders"], + "block_name": "text_encoder", }, "decoder": { "inputs": [ - "latents", + MellonParam.latents(display="input"), ], "model_inputs": [ - "vae", + MellonParam.vae(), ], "outputs": [ - "images", + MellonParam.images(), + MellonParam.videos(), + MellonParam.doc(), ], - "block_names": ["decode"], + "required_inputs": ["latents"], + "required_model_inputs": ["vae"], + "block_name": "decode", }, } -@dataclass(frozen=True) -class MellonParam: - name: str - label: str - type: str - display: Optional[str] = None - default: Any = None - min: Optional[float] = None - max: Optional[float] = None - step: Optional[float] = None - options: Any = None - value: Any = None - fieldOptions: Optional[Dict[str, Any]] = None - onChange: Any = None - onSignal: Any = None - _map_to_input: Any = None # the block input name this parameter maps to +def mark_required(label: str, marker: str = " *") -> str: + """Add required marker to label if not already present.""" + if label.endswith(marker): + return label + return f"{label}{marker}" - def to_dict(self) -> Dict[str, Any]: - data = asdict(self) - return {k: v for k, v in data.items() if not k.startswith("_") and v is not None} +def node_spec_to_mellon_dict(node_spec: Dict[str, Any], node_type: str) -> Dict[str, Any]: + """ + Convert a node spec dict into Mellon format. + + A node spec is how we define a Mellon diffusers node in code. This function converts it into the `params` map + format that Mellon UI expects. + + The `params` map is a dict where keys are parameter names and values are UI configuration: + ```python + {"seed": {"label": "Seed", "type": "int", "default": 0}} + ``` + + For Modular Mellon nodes, we need to distinguish: + - `inputs`: Pipeline inputs (e.g., seed, prompt, image) + - `model_inputs`: Model components (e.g., unet, vae, scheduler) + - `outputs`: Node outputs (e.g., latents, images) + + The node spec also includes: + - `required_inputs` / `required_model_inputs`: Which params are required (marked with *) + - `block_name`: The modular pipeline block this node corresponds to on backend + + We provide factory methods for common parameters (e.g., `MellonParam.seed()`, `MellonParam.unet()`) so you don't + have to manually specify all the UI configuration. + + Args: + node_spec: Dict with `inputs`, `model_inputs`, `outputs` (lists of MellonParam), + plus `required_inputs`, `required_model_inputs`, `block_name`. 
+ node_type: The node type string (e.g., "denoise", "controlnet") + + Returns: + Dict with: + - `params`: Flat dict of all params in Mellon UI format + - `input_names`: List of input parameter names + - `model_input_names`: List of model input parameter names + - `output_names`: List of output parameter names + - `block_name`: The backend block name + - `node_type`: The node type + + Example: + ```python + node_spec = { + "inputs": [MellonParam.seed(), MellonParam.prompt()], + "model_inputs": [MellonParam.unet()], + "outputs": [MellonParam.latents(display="output")], + "required_inputs": ["prompt"], + "required_model_inputs": ["unet"], + "block_name": "denoise", + } -@dataclass -class MellonNodeConfig(PushToHubMixin): + result = node_spec_to_mellon_dict(node_spec, "denoise") + # Returns: + # { + # "params": { + # "seed": {"label": "Seed", "type": "int", "default": 0}, + # "prompt": {"label": "Prompt *", "type": "string", "default": ""}, # * marks required + # "unet": {"label": "Denoise Model *", "type": "diffusers_auto_model", "display": "input"}, + # "latents": {"label": "Latents", "type": "latents", "display": "output"}, + # }, + # "input_names": ["seed", "prompt"], + # "model_input_names": ["unet"], + # "output_names": ["latents"], + # "block_name": "denoise", + # "node_type": "denoise", + # } + ``` """ - A MellonNodeConfig is a base class to build Mellon nodes UI with modular diffusers. + params = {} + input_names = [] + model_input_names = [] + output_names = [] + + required_inputs = node_spec.get("required_inputs", []) + required_model_inputs = node_spec.get("required_model_inputs", []) + + # Process inputs + for p in node_spec.get("inputs", []): + param_dict = p.to_dict() + if p.name in required_inputs: + param_dict["label"] = mark_required(param_dict["label"]) + params[p.name] = param_dict + input_names.append(p.name) + + # Process model_inputs + for p in node_spec.get("model_inputs", []): + param_dict = p.to_dict() + if p.name in required_model_inputs: + param_dict["label"] = mark_required(param_dict["label"]) + params[p.name] = param_dict + model_input_names.append(p.name) + + # Process outputs: add a prefix to the output name if it already exists as an input + for p in node_spec.get("outputs", []): + if p.name in input_names: + # rename to out_ + output_name = f"out_{p.name}" + else: + output_name = p.name + params[output_name] = p.to_dict() + output_names.append(output_name) - + return { + "params": params, + "input_names": input_names, + "model_input_names": model_input_names, + "output_names": output_names, + "block_name": node_spec.get("block_name"), + "node_type": node_type, + } - This is an experimental feature and is likely to change in the future. - +class MellonPipelineConfig: """ + Configuration for an entire Mellon pipeline containing multiple nodes. + + Accepts node specs as dicts with inputs/model_inputs/outputs lists of MellonParam, converts them to Mellon-ready + format, and handles save/load to Hub. 
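Note: one subtlety in `node_spec_to_mellon_dict` above is output/input name collisions: when a node both consumes and produces a value of the same name (e.g. `latents`), the output key is prefixed with `out_` so both fit in the flat `params` map. A small sketch of that behavior, using plain dicts instead of `MellonParam` objects:

```python
def merge_outputs(input_names, outputs):
    """Rename an output to 'out_<name>' when it would collide with an input key."""
    params, output_names = {}, []
    for name, cfg in outputs.items():
        key = f"out_{name}" if name in input_names else name
        params[key] = cfg
        output_names.append(key)
    return params, output_names


params, names = merge_outputs(
    input_names=["latents", "seed"],
    outputs={"latents": {"type": "latents", "display": "output"}},
)
print(names)   # ['out_latents'] -- the input key 'latents' stays untouched
print(params)  # {'out_latents': {'type': 'latents', 'display': 'output'}}
```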
+ + Example: + ```python + config = MellonPipelineConfig( + node_specs={ + "denoise": { + "inputs": [MellonParam.seed(), MellonParam.prompt()], + "model_inputs": [MellonParam.unet()], + "outputs": [MellonParam.latents(display="output")], + "required_inputs": ["prompt"], + "required_model_inputs": ["unet"], + "block_name": "denoise", + }, + "decoder": { + "inputs": [MellonParam.latents(display="input")], + "outputs": [MellonParam.images()], + "block_name": "decoder", + }, + }, + label="My Pipeline", + default_repo="user/my-pipeline", + default_dtype="float16", + ) - inputs: List[Union[str, MellonParam]] - model_inputs: List[Union[str, MellonParam]] - outputs: List[Union[str, MellonParam]] - blocks_names: list[str] - node_type: str - config_name = "mellon_config.json" - - def __post_init__(self): - if isinstance(self.inputs, list): - self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS) - if isinstance(self.model_inputs, list): - self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS) - if isinstance(self.outputs, list): - self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS) - - @staticmethod - def _resolve_params_list( - params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]] - ) -> Dict[str, Dict[str, Any]]: - def _resolve_param( - param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]] - ) -> Tuple[str, Dict[str, Any]]: - if isinstance(param, str): - if param not in default_params_map: - raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead") - return param, default_params_map[param].copy() - elif isinstance(param, MellonParam): - param_dict = param.to_dict() - param_name = param_dict.pop("name") - return param_name, param_dict + # Access Mellon format dict + denoise = config.node_params["denoise"] + input_names = denoise["input_names"] + params = denoise["params"] + + # Save to Hub + config.save("./my_config", push_to_hub=True, repo_id="user/my-pipeline") + + # Load from Hub + loaded = MellonPipelineConfig.load("user/my-pipeline") + ``` + """ + + config_name = "mellon_pipeline_config.json" + + def __init__( + self, + node_specs: Dict[str, Optional[Dict[str, Any]]], + label: str = "", + default_repo: str = "", + default_dtype: str = "", + ): + """ + Args: + node_specs: Dict mapping node_type to node spec or None. 
+ Node spec has: inputs, model_inputs, outputs, required_inputs, required_model_inputs, + block_name (all optional) + label: Human-readable label for the pipeline + default_repo: Default HuggingFace repo for this pipeline + default_dtype: Default dtype (e.g., "float16", "bfloat16") + """ + # Convert all node specs to Mellon format immediately + self.node_specs = node_specs + + self.label = label + self.default_repo = default_repo + self.default_dtype = default_dtype + + @property + def node_params(self) -> Dict[str, Any]: + """Lazily compute node_params from node_specs.""" + if self.node_specs is None: + return self._node_params + + params = {} + for node_type, spec in self.node_specs.items(): + if spec is None: + params[node_type] = None else: - raise ValueError( - f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead" - ) + params[node_type] = node_spec_to_mellon_dict(spec, node_type) + return params + + def __repr__(self) -> str: + lines = [ + f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})" + ] + for node_type, spec in self.node_specs.items(): + if spec is None: + lines.append(f" {node_type}: None") + else: + inputs = [p.name for p in spec.get("inputs", [])] + model_inputs = [p.name for p in spec.get("model_inputs", [])] + outputs = [p.name for p in spec.get("outputs", [])] + lines.append(f" {node_type}:") + lines.append(f" inputs: {inputs}") + lines.append(f" model_inputs: {model_inputs}") + lines.append(f" outputs: {outputs}") + return "\n".join(lines) - resolved = {} - for p in params: - logger.info(f" Resolving param: {p}") - name, cfg = _resolve_param(p, default_map) - if name in resolved: - raise ValueError(f"Duplicate param '{name}'") - resolved[name] = cfg - return resolved + def to_dict(self) -> Dict[str, Any]: + """Convert to a JSON-serializable dictionary.""" + return { + "label": self.label, + "default_repo": self.default_repo, + "default_dtype": self.default_dtype, + "node_params": self.node_params, + } @classmethod - @validate_hf_hub_args - def load_mellon_config( + def from_dict(cls, data: Dict[str, Any]) -> "MellonPipelineConfig": + """ + Create from a dictionary (loaded from JSON). + + Note: The mellon_params are already in Mellon format when loading from JSON. 
+ """ + instance = cls.__new__(cls) + instance.node_specs = None + instance._node_params = data.get("node_params", {}) + instance.label = data.get("label", "") + instance.default_repo = data.get("default_repo", "") + instance.default_dtype = data.get("default_dtype", "") + return instance + + def to_json_string(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=False) + "\n" + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """Save to a JSON file.""" + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + @classmethod + def from_json_file(cls, json_file_path: Union[str, os.PathLike]) -> "MellonPipelineConfig": + """Load from a JSON file.""" + with open(json_file_path, "r", encoding="utf-8") as reader: + data = json.load(reader) + return cls.from_dict(data) + + def save(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """Save the mellon pipeline config to a directory.""" + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + output_path = os.path.join(save_directory, self.config_name) + self.to_json_file(output_path) + logger.info(f"Pipeline config saved to {output_path}") + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", None) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + upload_file( + path_or_fileobj=output_path, + path_in_repo=self.config_name, + repo_id=repo_id, + token=token, + commit_message=commit_message or "Upload MellonPipelineConfig", + create_pr=create_pr, + ) + logger.info(f"Pipeline config pushed to hub: {repo_id}") + + @classmethod + def load( cls, pretrained_model_name_or_path: Union[str, os.PathLike], - return_unused_kwargs=False, - return_commit_hash=False, **kwargs, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - r""" - Load a model or scheduler configuration. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with - [`~ConfigMixin.save_config`]. - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. 
If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - return_unused_kwargs (`bool`, *optional*, defaults to `False): - Whether unused keyword arguments of the config are returned. - return_commit_hash (`bool`, *optional*, defaults to `False): - Whether the `commit_hash` of the loaded configuration are returned. - - Returns: - `dict`: - A dictionary of all the parameters stored in a JSON configuration file. - - """ + ) -> "MellonPipelineConfig": + """Load a pipeline config from a local path or Hugging Face Hub.""" cache_dir = kwargs.pop("cache_dir", None) local_dir = kwargs.pop("local_dir", None) local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") @@ -450,27 +865,18 @@ def load_mellon_config( token = kwargs.pop("token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if cls.config_name is None: - raise ValueError( - "`self.config_name` is not defined. Note that one should not load a config from " - "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" - ) if os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)): - # Load from a PyTorch checkpoint - config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) - else: - raise EnvironmentError( - f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}." 
- ) + config_file = os.path.join(pretrained_model_name_or_path, cls.config_name) + if not os.path.isfile(config_file): + raise EnvironmentError(f"No file named {cls.config_name} found in {pretrained_model_name_or_path}") else: try: - # Load from URL or cache if already cached config_file = hf_hub_download( pretrained_model_name_or_path, filename=cls.config_name, @@ -480,6 +886,7 @@ def load_mellon_config( local_files_only=local_files_only, token=token, revision=revision, + subfolder=subfolder, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, ) @@ -519,245 +926,170 @@ def load_mellon_config( f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a {cls.config_name} file" ) - try: - with open(config_file, "r", encoding="utf-8") as reader: - text = reader.read() - config_dict = json.loads(text) - commit_hash = extract_commit_hash(config_file) + try: + return cls.from_json_file(config_file) except (json.JSONDecodeError, UnicodeDecodeError): - raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.") - if not (return_unused_kwargs or return_commit_hash): - return config_dict - - outputs = (config_dict,) - - if return_unused_kwargs: - outputs += (kwargs,) - - if return_commit_hash: - outputs += (commit_hash,) - - return outputs - - def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + @classmethod + def from_blocks( + cls, + blocks, + template: Dict[str, Optional[Dict[str, Any]]] = None, + label: str = "", + default_repo: str = "", + default_dtype: str = "bfloat16", + ) -> "MellonPipelineConfig": """ - Save the Mellon node definition to a JSON file. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the configuration JSON file is saved (will be created if it does not exist). - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + Create MellonPipelineConfig by matching template against actual pipeline blocks. 
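Note: `from_blocks` (continued below) resolves each template entry through its `block_name`; if the pipeline's `sub_blocks` has no block of that name, the node spec is set to `None` and the node is skipped. A standalone sketch of that dispatch with made-up names (the real method also filters individual params per block):

```python
def resolve_node_specs(template, sub_block_names):
    """Map node_type -> spec, dropping entries whose backing block is missing."""
    resolved = {}
    for node_type, spec in template.items():
        if spec is None or spec.get("block_name") not in sub_block_names:
            resolved[node_type] = None  # node disabled for this pipeline
        else:
            resolved[node_type] = spec
    return resolved


template = {
    "denoise": {"block_name": "denoise"},
    "controlnet": None,                   # disabled in the template itself
    "decoder": {"block_name": "decode"},
}
print(resolve_node_specs(template, sub_block_names={"denoise", "text_encoder"}))
# -> {'denoise': {'block_name': 'denoise'}, 'controlnet': None, 'decoder': None}
```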
""" - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - # If we save using the predefined names, we can load using `from_config` - output_config_file = os.path.join(save_directory, self.config_name) - - self.to_json_file(output_config_file) - logger.info(f"Mellon node definition saved in {output_config_file}") - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - private = kwargs.pop("private", None) - create_pr = kwargs.pop("create_pr", False) - token = kwargs.pop("token", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - subfolder = kwargs.pop("subfolder", None) + if template is None: + template = DEFAULT_NODE_SPECS + + sub_block_map = dict(blocks.sub_blocks) + + def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]: + """Filter template spec params based on what the block actually supports.""" + block_input_names = set(block.input_names) + block_output_names = set(block.intermediate_output_names) + block_component_names = set(block.component_names) + + filtered_inputs = [ + p + for p in template_spec.get("inputs", []) + if p.required_block_params is None + or all(name in block_input_names for name in p.required_block_params) + ] + filtered_model_inputs = [ + p + for p in template_spec.get("model_inputs", []) + if p.required_block_params is None + or all(name in block_component_names for name in p.required_block_params) + ] + filtered_outputs = [ + p + for p in template_spec.get("outputs", []) + if p.required_block_params is None + or all(name in block_output_names for name in p.required_block_params) + ] + + filtered_input_names = {p.name for p in filtered_inputs} + filtered_model_input_names = {p.name for p in filtered_model_inputs} + + filtered_required_inputs = [ + r for r in template_spec.get("required_inputs", []) if r in filtered_input_names + ] + filtered_required_model_inputs = [ + r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names + ] + + return { + "inputs": filtered_inputs, + "model_inputs": filtered_model_inputs, + "outputs": filtered_outputs, + "required_inputs": filtered_required_inputs, + "required_model_inputs": filtered_required_model_inputs, + "block_name": template_spec.get("block_name"), + } + + # Build node specs + node_specs = {} + for node_type, template_spec in template.items(): + if template_spec is None: + node_specs[node_type] = None + continue + + block_name = template_spec.get("block_name") + if block_name is None or block_name not in sub_block_map: + node_specs[node_type] = None + continue + + node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name]) - self._upload_folder( - save_directory, - repo_id, - token=token, - commit_message=commit_message, - create_pr=create_pr, - subfolder=subfolder, - ) + return cls( + node_specs=node_specs, + label=label or getattr(blocks, "model_name", ""), + default_repo=default_repo, + default_dtype=default_dtype, + ) - def to_json_file(self, json_file_path: Union[str, os.PathLike]): + @classmethod + def from_custom_block( + cls, + block, + node_label: str = None, + input_types: Optional[Dict[str, str]] = None, + output_types: Optional[Dict[str, str]] = None, + ) -> "MellonPipelineConfig": """ - Save the Mellon schema dictionary to a JSON file. 
+ Create a MellonPipelineConfig from a custom block. Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file to save a configuration instance's parameters. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) + block: A block instance with `inputs`, `outputs`, and `expected_components`/`component_names` properties. + Each InputParam/OutputParam should have metadata={"mellon": ""} where type is one of: image, + video, text, checkbox, number, slider, dropdown, model. If metadata is None, maps to "custom". + node_label: The display label for the node. Defaults to block class name with spaces. + input_types: + Optional dict mapping input param names to mellon types. Overrides the block's metadata if provided. + Example: {"prompt": "textbox", "image": "image"} + output_types: + Optional dict mapping output param names to mellon types. Overrides the block's metadata if provided. + Example: {"prompt": "text", "images": "image"} - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string of the Mellon schema dict. - - Args: Returns: - `str`: String containing all the attributes that make up this configuration instance in JSON format. + MellonPipelineConfig instance """ - - mellon_dict = self.to_mellon_dict() - return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n" - - def to_mellon_dict(self) -> Dict[str, Any]: - """Return a JSON-serializable dict focusing on the Mellon schema fields only. - - params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}. - """ - # inputs/model_inputs/outputs are already normalized dicts - merged_params = {} - merged_params.update(self.inputs or {}) - merged_params.update(self.model_inputs or {}) - merged_params.update(self.outputs or {}) - - return { - "node_type": self.node_type, - "blocks_names": self.blocks_names, - "params": merged_params, + if node_label is None: + class_name = block.__class__.__name__ + node_label = "".join([" " + c if c.isupper() else c for c in class_name]).strip() + + if input_types is None: + input_types = {} + if output_types is None: + output_types = {} + + inputs = [] + model_inputs = [] + outputs = [] + + # Process block inputs + for input_param in block.inputs: + if input_param.name is None: + continue + if input_param.name in input_types: + input_param = copy.copy(input_param) + input_param.metadata = {"mellon": input_types[input_param.name]} + print(f" processing input: {input_param.name}, metadata: {input_param.metadata}") + inputs.append(input_param_to_mellon_param(input_param)) + + # Process block outputs + for output_param in block.outputs: + if output_param.name is None: + continue + if output_param.name in output_types: + output_param = copy.copy(output_param) + output_param.metadata = {"mellon": output_types[output_param.name]} + outputs.append(output_param_to_mellon_param(output_param)) + + # Process expected components (all map to model inputs) + component_names = block.component_names + for component_name in component_names: + model_inputs.append(MellonParam.Input.model(component_name)) + + # Always add doc output + outputs.append(MellonParam.doc()) + + node_spec = { + "inputs": inputs, + "model_inputs": model_inputs, + "outputs": outputs, + "required_inputs": [], + "required_model_inputs": [], + "block_name": "custom", } - @classmethod - def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig": - """Create a config from a Mellon schema dict produced by to_mellon_dict(). 
- - Splits the flat params dict back into inputs/model_inputs/outputs using the known key spaces from - MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. Unknown keys are treated as inputs by - default. - """ - flat_params = mellon_dict.get("params", {}) - - inputs: Dict[str, Any] = {} - model_inputs: Dict[str, Any] = {} - outputs: Dict[str, Any] = {} - - for param_name, param_dict in flat_params.items(): - if param_dict.get("display", "") == "output": - outputs[param_name] = param_dict - elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"): - model_inputs[param_name] = param_dict - else: - inputs[param_name] = param_dict - return cls( - inputs=inputs, - model_inputs=model_inputs, - outputs=outputs, - blocks_names=mellon_dict.get("blocks_names", []), - node_type=mellon_dict.get("node_type"), + node_specs={"custom": node_spec}, + label=node_label, ) - - # YiYi Notes: not used yet - @classmethod - def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig": - """ - Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type, - use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs. - """ - if node_type not in NODE_TYPE_PARAMS_MAP: - raise ValueError(f"Node type {node_type} not supported") - - blocks_names = list(blocks.sub_blocks.keys()) - - default_node_config = NODE_TYPE_PARAMS_MAP[node_type] - inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", []) - model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", []) - outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", []) - - for required_input_name in blocks.required_inputs: - if required_input_name not in inputs_list: - inputs_list.append( - MellonParam( - name=required_input_name, label=required_input_name, type=required_input_name, display="input" - ) - ) - - for component_spec in blocks.expected_components: - if component_spec.name not in model_inputs_list: - model_inputs_list.append( - MellonParam( - name=component_spec.name, - label=component_spec.name, - type="diffusers_auto_model", - display="input", - ) - ) - - return cls( - inputs=inputs_list, - model_inputs=model_inputs_list, - outputs=outputs_list, - blocks_names=blocks_names, - node_type=node_type, - ) - - -# Minimal modular registry for Mellon node configs -class ModularMellonNodeRegistry: - """Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig.""" - - def __init__(self): - self._registry = {} - self._initialized = False - - def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]): - if not self._initialized: - _initialize_registry(self) - self._registry[pipeline_cls] = node_params - - def get(self, pipeline_cls: type) -> MellonNodeConfig: - if not self._initialized: - _initialize_registry(self) - return self._registry.get(pipeline_cls, None) - - def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]: - if not self._initialized: - _initialize_registry(self) - return self._registry - - -def _register_preset_node_types( - pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry -): - """Register all node-type presets for a given pipeline class from a params map.""" - node_configs = {} - for node_type, spec in params_map.items(): - node_config = MellonNodeConfig( - inputs=spec.get("inputs", []), - model_inputs=spec.get("model_inputs", 
[]), - outputs=spec.get("outputs", []), - blocks_names=spec.get("block_names", []), - node_type=node_type, - ) - node_configs[node_type] = node_config - registry.register(pipeline_cls, node_configs) - - -def _initialize_registry(registry: ModularMellonNodeRegistry): - """Initialize the registry and register all available pipeline configs.""" - print("Initializing registry") - - registry._initialized = True - - try: - from .qwenimage.modular_pipeline import QwenImageModularPipeline - from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP - - _register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry) - except Exception: - raise Exception("Failed to register QwenImageModularPipeline") - - try: - from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline - from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP - - _register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry) - except Exception: - raise Exception("Failed to register StableDiffusionXLModularPipeline") diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index a6336de71a52..a5695736581f 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -34,6 +34,7 @@ from ..utils.hub_utils import load_or_create_model_card, populate_model_card from .components_manager import ComponentsManager from .modular_pipeline_utils import ( + MODULAR_MODEL_CARD_TEMPLATE, ComponentSpec, ConfigSpec, InputParam, @@ -41,6 +42,7 @@ OutputParam, format_components, format_configs, + generate_modular_model_card_content, make_doc_string, ) @@ -58,9 +60,13 @@ ("wan", "WanModularPipeline"), ("flux", "FluxModularPipeline"), ("flux-kontext", "FluxKontextModularPipeline"), + ("flux2", "Flux2ModularPipeline"), + ("flux2-klein", "Flux2KleinModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), + ("qwenimage-layered", "QwenImageLayeredModularPipeline"), + ("z-image", "ZImageModularPipeline"), ] ) @@ -229,7 +235,7 @@ def format_value(v): class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """ - Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, + Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks [`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks. 
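Note: the `modular_pipeline.py` hunks below replace the trigger-map dispatch in `AutoPipelineBlocks` with a `ConditionalPipelineBlocks` base class whose subclasses implement `select_block()`. A minimal toy sketch of that contract (not the diffusers classes themselves), mirroring the auto-selection rule of "first present trigger wins, otherwise fall back to the default block":

```python
from typing import Optional


class ToyConditionalBlocks:
    """Toy illustration of the select_block() contract introduced below."""

    block_names = ["inpaint", "img2img", "text2img"]
    block_trigger_inputs = ["mask", "image", None]  # None marks the default entry
    default_block_name = "text2img"

    def select_block(self, **triggers) -> Optional[str]:
        # First trigger that is actually provided wins; None means "use the default".
        for trigger, name in zip(self.block_trigger_inputs, self.block_names):
            if trigger is not None and triggers.get(trigger) is not None:
                return name
        return None


blocks = ToyConditionalBlocks()
print(blocks.select_block(mask="m", image="i"))            # 'inpaint' (mask wins)
print(blocks.select_block(image="i"))                      # 'img2img'
print(blocks.select_block() or blocks.default_block_name)  # 'text2img' (default)
```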
@@ -499,15 +505,19 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> @property def input_names(self) -> List[str]: - return [input_param.name for input_param in self.inputs] + return [input_param.name for input_param in self.inputs if input_param.name is not None] @property def intermediate_output_names(self) -> List[str]: - return [output_param.name for output_param in self.intermediate_outputs] + return [output_param.name for output_param in self.intermediate_outputs if output_param.name is not None] @property def output_names(self) -> List[str]: - return [output_param.name for output_param in self.outputs] + return [output_param.name for output_param in self.outputs if output_param.name is not None] + + @property + def component_names(self) -> List[str]: + return [component.name for component in self.expected_components] @property def doc(self): @@ -521,9 +531,10 @@ def doc(self): ) -class AutoPipelineBlocks(ModularPipelineBlocks): +class ConditionalPipelineBlocks(ModularPipelineBlocks): """ - A Pipeline Blocks that automatically selects a block to run based on the inputs. + A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the + `select_block` method to define the logic for selecting the block. This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the library implements for all the pipeline blocks (such as loading or saving etc.) @@ -533,12 +544,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks): Attributes: block_classes: List of block classes to be used block_names: List of prefixes for each block - block_trigger_inputs: List of input names that trigger specific blocks, with None for default + block_trigger_inputs: List of input names that select_block() uses to determine which block to run """ block_classes = [] block_names = [] block_trigger_inputs = [] + default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided def __init__(self): sub_blocks = InsertableDict() @@ -548,26 +560,15 @@ def __init__(self): else: sub_blocks[block_name] = block self.sub_blocks = sub_blocks - if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + if not (len(self.block_classes) == len(self.block_names)): raise ValueError( - f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." + f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same." ) - default_blocks = [t for t in self.block_trigger_inputs if t is None] - # can only have 1 or 0 default block, and has to put in the last - # the order of blocks matters here because the first block with matching trigger will be dispatched - # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"] - # as long as mask is provided, it is inpaint; if only image is provided, it is img2img - if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None): + if self.default_block_name is not None and self.default_block_name not in self.block_names: raise ValueError( - f"In {self.__class__.__name__}, exactly one None must be specified as the last element " - "in block_trigger_inputs." 
+ f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}" ) - # Map trigger inputs to block objects - self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values())) - self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys())) - self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs)) - @property def model_name(self): return next(iter(self.sub_blocks.values())).model_name @@ -596,8 +597,10 @@ def expected_configs(self): @property def required_inputs(self) -> List[str]: - if None not in self.block_trigger_inputs: + # no default block means this conditional block can be skipped entirely + if self.default_block_name is None: return [] + first_block = next(iter(self.sub_blocks.values())) required_by_all = set(getattr(first_block, "required_inputs", set())) @@ -608,7 +611,6 @@ def required_inputs(self) -> List[str]: return list(required_by_all) - # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] @@ -633,36 +635,9 @@ def outputs(self) -> List[str]: combined_outputs = self.combine_outputs(*named_outputs) return combined_outputs - @torch.no_grad() - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - # Find default block first (if any) - - block = self.trigger_to_block_map.get(None) - for input_name in self.block_trigger_inputs: - if input_name is not None and state.get(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break - - if block is None: - logger.info(f"skipping auto block: {self.__class__.__name__}") - return pipeline, state - - try: - logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}") - return block(pipeline, state) - except Exception as e: - error_msg = ( - f"\nError in block: {block.__class__.__name__}\n" - f"Error details: {str(e)}\n" - f"Traceback:\n{traceback.format_exc()}" - ) - logger.error(error_msg) - raise - - def _get_trigger_inputs(self): + def _get_trigger_inputs(self) -> set: """ - Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique - block_trigger_inputs values + Returns a set of all unique trigger input values found in this block and nested blocks. """ def fn_recursive_get_trigger(blocks): @@ -670,9 +645,8 @@ def fn_recursive_get_trigger(blocks): if blocks is not None: for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. 
auto block) + # Check if current block has block_trigger_inputs if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has sub_blocks, recursively check them @@ -682,15 +656,57 @@ def fn_recursive_get_trigger(blocks): return trigger_values - trigger_inputs = set(self.block_trigger_inputs) - trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks)) + # Start with this block's block_trigger_inputs + all_triggers = {t for t in self.block_trigger_inputs if t is not None} + # Add nested triggers + all_triggers.update(fn_recursive_get_trigger(self.sub_blocks)) - return trigger_inputs + return all_triggers @property def trigger_inputs(self): + """All trigger inputs including from nested blocks.""" return self._get_trigger_inputs() + def select_block(self, **kwargs) -> Optional[str]: + """ + Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic + for selecting the block. + + Args: + **kwargs: Trigger input names and their values from the state. + + Returns: + Optional[str]: The name of the block to run, or None to use default/skip. + """ + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.") + + @torch.no_grad() + def __call__(self, pipeline, state: PipelineState) -> PipelineState: + trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None} + block_name = self.select_block(**trigger_kwargs) + + if block_name is None: + block_name = self.default_block_name + + if block_name is None: + logger.info(f"skipping conditional block: {self.__class__.__name__}") + return pipeline, state + + block = self.sub_blocks[block_name] + + try: + logger.info(f"Running block: {block.__class__.__name__}") + return block(pipeline, state) + except Exception as e: + error_msg = ( + f"\nError in block: {block.__class__.__name__}\n" + f"Error details: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ) + logger.error(error_msg) + raise + def __repr__(self): class_name = self.__class__.__name__ base_class = self.__class__.__bases__[0].__name__ @@ -702,7 +718,7 @@ def __repr__(self): header += "\n" header += " " + "=" * 100 + "\n" header += " This pipeline contains blocks that are selected at runtime based on inputs.\n" - header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n" + header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n" header += " " + "=" * 100 + "\n\n" # Format description with proper indentation @@ -723,31 +739,20 @@ def __repr__(self): expected_configs = getattr(self, "expected_configs", []) configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - # Blocks section - moved to the end with simplified format + # Blocks section blocks_str = " Sub-Blocks:\n" for i, (name, block) in enumerate(self.sub_blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, "block_to_trigger_map"): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} 
({block.__class__.__name__})\n" + if name == self.default_block_name: + addtional_str = " [default]" else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + addtional_str = "" + blocks_str += f" • {name}{addtional_str} ({block.__class__.__name__})\n" # Add block description - desc_lines = block.description.split("\n") - indented_desc = desc_lines[0] - if len(desc_lines) > 1: - indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:]) + block_desc_lines = block.description.split("\n") + indented_desc = block_desc_lines[0] + if len(block_desc_lines) > 1: + indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:]) blocks_str += f" Description: {indented_desc}\n\n" # Build the representation with conditional sections @@ -778,6 +783,35 @@ def doc(self): ) +class AutoPipelineBlocks(ConditionalPipelineBlocks): + """ + A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs. + """ + + def __init__(self): + super().__init__() + + if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): + raise ValueError( + f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." + ) + + @property + def default_block_name(self) -> Optional[str]: + """Derive default_block_name from block_trigger_inputs (None entry).""" + if None in self.block_trigger_inputs: + idx = self.block_trigger_inputs.index(None) + return self.block_names[idx] + return None + + def select_block(self, **kwargs) -> Optional[str]: + """Select block based on which trigger input is present (not None).""" + for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names): + if trigger_input is not None and kwargs.get(trigger_input) is not None: + return block_name + return None + + class SequentialPipelineBlocks(ModularPipelineBlocks): """ A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in @@ -879,7 +913,8 @@ def _get_inputs(self): # Only add outputs if the block cannot be skipped should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: + if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None: + # ConditionalPipelineBlocks without default can be skipped should_add_outputs = False if should_add_outputs: @@ -942,8 +977,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: def _get_trigger_inputs(self): """ - Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique - block_trigger_inputs values + Returns a set of all unique trigger input values found in the blocks. """ def fn_recursive_get_trigger(blocks): @@ -951,9 +985,8 @@ def fn_recursive_get_trigger(blocks): if blocks is not None: for name, block in blocks.items(): - # Check if current block has trigger inputs(i.e. 
auto block) + # Check if current block has block_trigger_inputs (ConditionalPipelineBlocks) if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None: - # Add all non-None values from the trigger inputs list trigger_values.update(t for t in block.block_trigger_inputs if t is not None) # If block has sub_blocks, recursively check them @@ -969,82 +1002,84 @@ def fn_recursive_get_trigger(blocks): def trigger_inputs(self): return self._get_trigger_inputs() - def _traverse_trigger_blocks(self, trigger_inputs): - # Convert trigger_inputs to a set for easier manipulation - active_triggers = set(trigger_inputs) + def _traverse_trigger_blocks(self, active_inputs): + """ + Traverse blocks and select which ones would run given the active inputs. + + Args: + active_inputs: Dict of input names to values that are "present" + + Returns: + OrderedDict of block_name -> block that would execute + """ - def fn_recursive_traverse(block, block_name, active_triggers): + def fn_recursive_traverse(block, block_name, active_inputs): result_blocks = OrderedDict() - # sequential(include loopsequential) or PipelineBlock - if not hasattr(block, "block_trigger_inputs"): - if block.sub_blocks: - # sequential or LoopSequentialPipelineBlocks (keep traversing) - for sub_block_name, sub_block in block.sub_blocks.items(): - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers) - blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()} - result_blocks.update(blocks_to_update) + # ConditionalPipelineBlocks (includes AutoPipelineBlocks) + if isinstance(block, ConditionalPipelineBlocks): + trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs} + selected_block_name = block.select_block(**trigger_kwargs) + + if selected_block_name is None: + selected_block_name = block.default_block_name + + if selected_block_name is None: + return result_blocks + + selected_block = block.sub_blocks[selected_block_name] + + if selected_block.sub_blocks: + result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs)) else: - # PipelineBlock - result_blocks[block_name] = block - # Add this block's output names to active triggers if defined - if hasattr(block, "outputs"): - active_triggers.update(out.name for out in block.outputs) + result_blocks[block_name] = selected_block + if hasattr(selected_block, "outputs"): + for out in selected_block.outputs: + active_inputs[out.name] = True + return result_blocks - # auto + # SequentialPipelineBlocks or LoopSequentialPipelineBlocks + if block.sub_blocks: + for sub_block_name, sub_block in block.sub_blocks.items(): + blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs) + blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()} + result_blocks.update(blocks_to_update) else: - # Find first block_trigger_input that matches any value in our active_triggers - this_block = None - for trigger_input in block.block_trigger_inputs: - if trigger_input is not None and trigger_input in active_triggers: - this_block = block.trigger_to_block_map[trigger_input] - break - - # If no matches found, try to get the default (None) block - if this_block is None and None in block.block_trigger_inputs: - this_block = block.trigger_to_block_map[None] - - if this_block is not None: - # sequential/auto (keep traversing) - if this_block.sub_blocks: - 
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers)) - else: - # PipelineBlock - result_blocks[block_name] = this_block - # Add this block's output names to active triggers if defined - # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute? - if hasattr(this_block, "outputs"): - active_triggers.update(out.name for out in this_block.outputs) + result_blocks[block_name] = block + if hasattr(block, "outputs"): + for out in block.outputs: + active_inputs[out.name] = True return result_blocks all_blocks = OrderedDict() for block_name, block in self.sub_blocks.items(): - blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers) + blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs) all_blocks.update(blocks_to_update) return all_blocks - def get_execution_blocks(self, *trigger_inputs): - trigger_inputs_all = self.trigger_inputs + def get_execution_blocks(self, **kwargs): + """ + Get the blocks that would execute given the specified inputs. - if trigger_inputs is not None: - if not isinstance(trigger_inputs, (list, tuple, set)): - trigger_inputs = [trigger_inputs] - invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all] - if invalid_inputs: - logger.warning( - f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}" - ) - trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all] + Args: + **kwargs: Input names and values. Only trigger inputs affect block selection. + Pass any inputs that would be non-None at runtime. - if trigger_inputs is None: - if None in trigger_inputs_all: - trigger_inputs = [None] - else: - trigger_inputs = [trigger_inputs_all[0]] - blocks_triggered = self._traverse_trigger_blocks(trigger_inputs) + Returns: + SequentialPipelineBlocks containing only the blocks that would execute + + Example: + # Get blocks for inpainting workflow blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, + image=image) + + # Get blocks for text2image workflow blocks = pipeline.get_execution_blocks(prompt="a cat") + """ + # Filter out None values + active_inputs = {k: v for k, v in kwargs.items() if v is not None} + + blocks_triggered = self._traverse_trigger_blocks(active_inputs) return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered) def __repr__(self): @@ -1061,7 +1096,7 @@ def __repr__(self): header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n" # Get first trigger input as example example_input = next(t for t in self.trigger_inputs if t is not None) - header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n" + header += f" Use `get_execution_blocks()` to see selected blocks (e.g. 
`get_execution_blocks({example_input}=...)`).\n" header += " " + "=" * 100 + "\n\n" # Format description with proper indentation @@ -1085,22 +1120,8 @@ def __repr__(self): # Blocks section - moved to the end with simplified format blocks_str = " Sub-Blocks:\n" for i, (name, block) in enumerate(self.sub_blocks.items()): - # Get trigger input for this block - trigger = None - if hasattr(self, "block_to_trigger_map"): - trigger = self.block_to_trigger_map.get(name) - # Format the trigger info - if trigger is None: - trigger_str = "[default]" - elif isinstance(trigger, (list, tuple)): - trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]" - else: - trigger_str = f"[trigger: {trigger}]" - # For AutoPipelineBlocks, add bullet points - blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n" - else: - # For SequentialPipelineBlocks, show execution order - blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" + # show execution order + blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n" # Add block description desc_lines = block.description.split("\n") @@ -1224,15 +1245,9 @@ def _get_inputs(self): if inp.name not in outputs and inp not in inputs: inputs.append(inp) - # Only add outputs if the block cannot be skipped - should_add_outputs = True - if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: - should_add_outputs = False - - if should_add_outputs: - # Add this block's outputs - block_intermediate_outputs = [out.name for out in block.intermediate_outputs] - outputs.update(block_intermediate_outputs) + # Add this block's outputs + block_intermediate_outputs = [out.name for out in block.intermediate_outputs] + outputs.update(block_intermediate_outputs) for input_param in inputs: if input_param.name in self.required_inputs: @@ -1289,6 +1304,14 @@ def __init__(self): sub_blocks[block_name] = block self.sub_blocks = sub_blocks + # Validate that sub_blocks are only leaf blocks + for block_name, block in self.sub_blocks.items(): + if block.sub_blocks: + raise ValueError( + f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). " + f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks." + ) + @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks": """ @@ -1523,10 +1546,8 @@ def __init__( if blocks is None: if modular_config_dict is not None: blocks_class_name = modular_config_dict.get("_blocks_class_name") - elif config_dict is not None: - blocks_class_name = self.get_default_blocks_name(config_dict) else: - blocks_class_name = None + blocks_class_name = self.get_default_blocks_name(config_dict) if blocks_class_name is not None: diffusers_module = importlib.import_module("diffusers") blocks_class = getattr(diffusers_module, blocks_class_name) @@ -1585,7 +1606,6 @@ def __init__( for name, config_spec in self._config_specs.items(): default_configs[name] = config_spec.default self.register_to_config(**default_configs) - self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) @property @@ -1624,7 +1644,10 @@ def _load_pipeline_config( return None, config_dict except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") + raise EnvironmentError( + f"Failed to load config from '{pretrained_model_name_or_path}'. " + f"Could not find or load 'modular_model_index.json' or 'model_index.json'." 
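The reworked `get_execution_blocks()` shown above now selects blocks from keyword inputs instead of positional trigger names. A minimal sketch of the new call pattern, assuming the `QwenImageAutoBlocks` preset added later in this PR; the extra input names are illustrative, and only values that are not None influence selection:

    from PIL import Image
    from diffusers.modular_pipelines.qwenimage import QwenImageAutoBlocks

    blocks = QwenImageAutoBlocks()
    print(blocks.trigger_inputs)  # the input names that switch conditional branches

    # dummy values; block selection only checks that an input is not None
    image = Image.new("RGB", (1024, 1024))
    mask = Image.new("L", (1024, 1024))

    # text-to-image path: only prompt-driven blocks are selected
    t2i_blocks = blocks.get_execution_blocks(prompt="a cat")

    # image-conditioned path: the extra inputs activate the matching conditional sub-blocks
    inpaint_blocks = blocks.get_execution_blocks(prompt="a cat", image=image, mask_image=mask)
    print(inpaint_blocks)  # a SequentialPipelineBlocks containing only the blocks that would run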
+ ) from e return None, None @@ -1732,9 +1755,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + # Generate modular pipeline card content + card_content = generate_modular_model_card_content(self.blocks) + # Create a new empty model card and eventually tag it - model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True) - model_card = populate_model_card(model_card) + model_card = load_or_create_model_card( + repo_id, + token=token, + is_pipeline=True, + model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content), + is_modular=True, + ) + model_card = populate_model_card(model_card, tags=card_content["tags"]) + model_card.save(os.path.join(save_directory, "README.md")) # YiYi TODO: maybe order the json file to make it more readable: configs first, then components @@ -2122,6 +2155,8 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwarg name for name in self._component_specs.keys() if self._component_specs[name].default_creation_method == "from_pretrained" + and self._component_specs[name].pretrained_model_name_or_path is not None + and getattr(self, name, None) is None ] elif isinstance(names, str): names = [names] @@ -2549,7 +2584,11 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = kwargs_type = expected_input_param.kwargs_type if name in passed_kwargs: state.set(name, passed_kwargs.pop(name), kwargs_type) - elif name not in state.values: + elif kwargs_type is not None and kwargs_type in passed_kwargs: + kwargs_dict = passed_kwargs.pop(kwargs_type) + for k, v in kwargs_dict.items(): + state.set(k, v, kwargs_type) + elif name is not None and name not in state.values: state.set(name, default, kwargs_type) # Warn about unexpected inputs diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index aa421a53727b..9e11fb7ef79b 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -15,14 +15,15 @@ import inspect import re from collections import OrderedDict -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from typing import Any, Dict, List, Literal, Optional, Type, Union +import PIL.Image import torch from ..configuration_utils import ConfigMixin, FrozenDict from ..loaders.single_file_utils import _is_single_file_path_or_url -from ..utils import is_torch_available, logging +from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging if is_torch_available(): @@ -30,6 +31,30 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Template for modular pipeline model card description with placeholders +MODULAR_MODEL_CARD_TEMPLATE = """{model_description} + +## Example Usage + +[TODO] + +## Pipeline Architecture + +This modular pipeline is composed of the following blocks: + +{blocks_description} {trigger_inputs_section} + +## Model Components + +{components_description} {configs_section} + +## Input/Output Specification + +### Inputs {inputs_description} + +### Outputs {outputs_description} +""" + class InsertableDict(OrderedDict): def insert(self, key, value, index): @@ -185,7 +210,7 @@ def loading_fields(cls) -> List[str]: """ Return the names of all loading‐related fields (i.e. 
those whose field.metadata["loading"] is True). """ - return [f.name for f in fields(cls) if f.metadata.get("loading", False)] + return DIFFUSERS_LOAD_ID_FIELDS.copy() @property def load_id(self) -> str: @@ -197,7 +222,7 @@ def load_id(self) -> str: return "null" parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] - return "|".join(p for p in parts if p) + return "|".join(parts) @classmethod def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]: @@ -323,11 +348,192 @@ class ConfigSpec: description: Optional[str] = None -# YiYi Notes: both inputs and intermediate_inputs are InputParam objects -# however some fields are not relevant for intermediate_inputs -# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed -# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs -# -> should we use different class for inputs and intermediate_inputs? +# ====================================================== +# InputParam and OutputParam templates +# ====================================================== + +INPUT_PARAM_TEMPLATES = { + "prompt": { + "type_hint": str, + "required": True, + "description": "The prompt or prompts to guide image generation.", + }, + "negative_prompt": { + "type_hint": str, + "description": "The prompt or prompts not to guide the image generation.", + }, + "max_sequence_length": { + "type_hint": int, + "default": 512, + "description": "Maximum sequence length for prompt encoding.", + }, + "height": { + "type_hint": int, + "description": "The height in pixels of the generated image.", + }, + "width": { + "type_hint": int, + "description": "The width in pixels of the generated image.", + }, + "num_inference_steps": { + "type_hint": int, + "default": 50, + "description": "The number of denoising steps.", + }, + "num_images_per_prompt": { + "type_hint": int, + "default": 1, + "description": "The number of images to generate per prompt.", + }, + "generator": { + "type_hint": torch.Generator, + "description": "Torch generator for deterministic generation.", + }, + "sigmas": { + "type_hint": List[float], + "description": "Custom sigmas for the denoising process.", + }, + "strength": { + "type_hint": float, + "default": 0.9, + "description": "Strength for img2img/inpainting.", + }, + "image": { + "type_hint": Union[PIL.Image.Image, List[PIL.Image.Image]], + "required": True, + "description": "Reference image(s) for denoising. Can be a single image or list of images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Pre-generated noisy latents for image generation.", + }, + "timesteps": { + "type_hint": torch.Tensor, + "description": "Timesteps for the denoising process.", + }, + "output_type": { + "type_hint": str, + "default": "pil", + "description": "Output format: 'pil', 'np', 'pt'.", + }, + "attention_kwargs": { + "type_hint": Dict[str, Any], + "description": "Additional kwargs for attention processors.", + }, + "denoiser_input_fields": { + "name": None, + "kwargs_type": "denoiser_input_fields", + "description": "conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", + }, + # inpainting + "mask_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Mask image for inpainting.", + }, + "padding_mask_crop": { + "type_hint": int, + "description": "Padding for mask cropping in inpainting.", + }, + # controlnet + "control_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Control image for ControlNet conditioning.", + }, + "control_guidance_start": { + "type_hint": float, + "default": 0.0, + "description": "When to start applying ControlNet.", + }, + "control_guidance_end": { + "type_hint": float, + "default": 1.0, + "description": "When to stop applying ControlNet.", + }, + "controlnet_conditioning_scale": { + "type_hint": float, + "default": 1.0, + "description": "Scale for ControlNet conditioning.", + }, + "layers": { + "type_hint": int, + "default": 4, + "description": "Number of layers to extract from the image", + }, + # common intermediate inputs + "prompt_embeds": { + "type_hint": torch.Tensor, + "required": True, + "description": "text embeddings used to guide the image generation. Can be generated from text_encoder step.", + }, + "prompt_embeds_mask": { + "type_hint": torch.Tensor, + "required": True, + "description": "mask for the text embeddings. Can be generated from text_encoder step.", + }, + "negative_prompt_embeds": { + "type_hint": torch.Tensor, + "description": "negative text embeddings used to guide the image generation. Can be generated from text_encoder step.", + }, + "negative_prompt_embeds_mask": { + "type_hint": torch.Tensor, + "description": "mask for the negative text embeddings. Can be generated from text_encoder step.", + }, + "image_latents": { + "type_hint": torch.Tensor, + "required": True, + "description": "image latents used to guide the image generation. Can be generated from vae_encoder step.", + }, + "batch_size": { + "type_hint": int, + "default": 1, + "description": "Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", + }, + "dtype": { + "type_hint": torch.dtype, + "default": torch.float32, + "description": "The dtype of the model inputs, can be generated in input step.", + }, +} + +OUTPUT_PARAM_TEMPLATES = { + "images": { + "type_hint": List[PIL.Image.Image], + "description": "Generated images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Denoised latents.", + }, + # intermediate outputs + "prompt_embeds": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The prompt embeddings.", + }, + "prompt_embeds_mask": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The encoder attention mask.", + }, + "negative_prompt_embeds": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The negative prompt embeddings.", + }, + "negative_prompt_embeds_mask": { + "type_hint": torch.Tensor, + "kwargs_type": "denoiser_input_fields", + "description": "The negative prompt embeddings mask.", + }, + "image_latents": { + "type_hint": torch.Tensor, + "description": "The latent representation of the input image.", + }, +} + + @dataclass class InputParam: """Specification for an input parameter.""" @@ -337,11 +543,32 @@ class InputParam: default: Any = None required: bool = False description: str = "" - kwargs_type: str = None # YiYi Notes: remove this feature (maybe) + kwargs_type: str = None + metadata: Dict[str, Any] = None def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + @classmethod + def template(cls, template_name: str, note: str = None, **overrides) -> "InputParam": + """Get template for name if exists, otherwise raise ValueError.""" + if template_name not in INPUT_PARAM_TEMPLATES: + raise ValueError(f"InputParam template for {template_name} not found") + + template_kwargs = INPUT_PARAM_TEMPLATES[template_name].copy() + + # Determine the actual param name: + # 1. From overrides if provided + # 2. From template if present + # 3. Fall back to template_name + name = overrides.pop("name", template_kwargs.pop("name", template_name)) + + if note and "description" in template_kwargs: + template_kwargs["description"] = f"{template_kwargs['description']} ({note})" + + template_kwargs.update(overrides) + return cls(name=name, **template_kwargs) + @dataclass class OutputParam: @@ -350,13 +577,34 @@ class OutputParam: name: str type_hint: Any = None description: str = "" - kwargs_type: str = None # YiYi notes: remove this feature (maybe) + kwargs_type: str = None + metadata: Dict[str, Any] = None def __repr__(self): return ( f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" ) + @classmethod + def template(cls, template_name: str, note: str = None, **overrides) -> "OutputParam": + """Get template for name if exists, otherwise raise ValueError.""" + if template_name not in OUTPUT_PARAM_TEMPLATES: + raise ValueError(f"OutputParam template for {template_name} not found") + + template_kwargs = OUTPUT_PARAM_TEMPLATES[template_name].copy() + + # Determine the actual param name: + # 1. From overrides if provided + # 2. From template if present + # 3. 
Fall back to template_name + name = overrides.pop("name", template_kwargs.pop("name", template_name)) + + if note and "description" in template_kwargs: + template_kwargs["description"] = f"{template_kwargs['description']} ({note})" + + template_kwargs.update(overrides) + return cls(name=name, **template_kwargs) + def format_inputs_short(inputs): """ @@ -509,10 +757,12 @@ def wrap_text(text, indent, max_length): desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" + else: + param_str += f"\n{desc_indent}TODO: Add description." formatted_params.append(param_str) - return "\n\n".join(formatted_params) + return "\n".join(formatted_params) def format_input_params(input_params, indent_level=4, max_line_length=115): @@ -582,7 +832,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) - if field_value is not None: + if field_value: loading_field_values.append(f"{field_name}={field_value}") # Add loading field information if available @@ -669,17 +919,17 @@ def make_doc_string( # Add description if description: desc_lines = description.strip().split("\n") - aligned_desc = "\n".join(" " + line for line in desc_lines) + aligned_desc = "\n".join(" " + line.rstrip() for line in desc_lines) output += aligned_desc + "\n\n" # Add components section if provided if expected_components and len(expected_components) > 0: - components_str = format_components(expected_components, indent_level=2) + components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) output += components_str + "\n\n" # Add configs section if provided if expected_configs and len(expected_configs) > 0: - configs_str = format_configs(expected_configs, indent_level=2) + configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) output += configs_str + "\n\n" # Add inputs section @@ -690,3 +940,178 @@ def make_doc_string( output += format_output_params(outputs, indent_level=2) return output + + +def generate_modular_model_card_content(blocks) -> Dict[str, Any]: + """ + Generate model card content for a modular pipeline. + + This function creates a comprehensive model card with descriptions of the pipeline's architecture, components, + configurations, inputs, and outputs. 
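    A minimal usage sketch (`blocks` stands for any `ModularPipelineBlocks` instance; `save_pretrained` performs the same two steps below when it writes the README):

        # assemble the card sections from the blocks definition
        card_content = generate_modular_model_card_content(blocks)

        # render the free-form model description with the module-level template
        description = MODULAR_MODEL_CARD_TEMPLATE.format(**card_content)

        # tags are inferred from the trigger inputs, e.g. ["modular-diffusers", "diffusers", "text-to-image"]
        tags = card_content["tags"]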
+ + Args: + blocks: The pipeline's blocks object containing all pipeline specifications + + Returns: + Dict[str, Any]: A dictionary containing formatted content sections: + - pipeline_name: Name of the pipeline + - model_description: Overall description with pipeline type + - blocks_description: Detailed architecture of blocks + - components_description: List of required components + - configs_section: Configuration parameters section + - inputs_description: Input parameters specification + - outputs_description: Output parameters specification + - trigger_inputs_section: Conditional execution information + - tags: List of relevant tags for the model card + """ + blocks_class_name = blocks.__class__.__name__ + pipeline_name = blocks_class_name.replace("Blocks", " Pipeline") + description = getattr(blocks, "description", "A modular diffusion pipeline.") + + # generate blocks architecture description + blocks_desc_parts = [] + sub_blocks = getattr(blocks, "sub_blocks", None) or {} + if sub_blocks: + for i, (name, block) in enumerate(sub_blocks.items()): + block_class = block.__class__.__name__ + block_desc = block.description.split("\n")[0] if getattr(block, "description", "") else "" + blocks_desc_parts.append(f"{i + 1}. **{name}** (`{block_class}`)") + if block_desc: + blocks_desc_parts.append(f" - {block_desc}") + + # add sub-blocks if any + if hasattr(block, "sub_blocks") and block.sub_blocks: + for sub_name, sub_block in block.sub_blocks.items(): + sub_class = sub_block.__class__.__name__ + sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else "" + blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`") + if sub_desc: + blocks_desc_parts.append(f" - {sub_desc}") + + blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined." + + components = getattr(blocks, "expected_components", []) + if components: + components_str = format_components(components, indent_level=0, add_empty_lines=False) + # remove the "Components:" header since template has its own + components_description = components_str.replace("Components:\n", "").strip() + if components_description: + # Convert to enumerated list + lines = [line.strip() for line in components_description.split("\n") if line.strip()] + enumerated_lines = [f"{i + 1}. {line}" for i, line in enumerate(lines)] + components_description = "\n".join(enumerated_lines) + else: + components_description = "No specific components required." + else: + components_description = "No specific components required. Components can be loaded dynamically." 
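For reference, the `InputParam.template()` / `OutputParam.template()` helpers introduced earlier in this file resolve a named preset and let callers override fields or append a note to the description; a minimal sketch:

    # resolve the "num_inference_steps" preset: type_hint=int, default=50
    steps = InputParam.template("num_inference_steps")

    # override a preset field, or append a parenthesized note to the preset description
    steps_fast = InputParam.template("num_inference_steps", default=30)
    image_latents = InputParam.template(
        "image_latents", note="Can be generated from vae encoder and updated in input step."
    )

    # output presets work the same way and can carry a kwargs_type grouping
    prompt_embeds = OutputParam.template("prompt_embeds")  # kwargs_type="denoiser_input_fields"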
+ + configs = getattr(blocks, "expected_configs", []) + configs_section = "" + if configs: + configs_str = format_configs(configs, indent_level=0, add_empty_lines=False) + configs_description = configs_str.replace("Configs:\n", "").strip() + if configs_description: + configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}" + + inputs = blocks.inputs + outputs = blocks.outputs + + # format inputs as markdown list + inputs_parts = [] + required_inputs = [inp for inp in inputs if inp.required] + optional_inputs = [inp for inp in inputs if not inp.required] + + if required_inputs: + inputs_parts.append("**Required:**\n") + for inp in required_inputs: + if hasattr(inp.type_hint, "__name__"): + type_str = inp.type_hint.__name__ + elif inp.type_hint is not None: + type_str = str(inp.type_hint).replace("typing.", "") + else: + type_str = "Any" + desc = inp.description or "No description provided" + inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}") + + if optional_inputs: + if required_inputs: + inputs_parts.append("") + inputs_parts.append("**Optional:**\n") + for inp in optional_inputs: + if hasattr(inp.type_hint, "__name__"): + type_str = inp.type_hint.__name__ + elif inp.type_hint is not None: + type_str = str(inp.type_hint).replace("typing.", "") + else: + type_str = "Any" + desc = inp.description or "No description provided" + default_str = f", default: `{inp.default}`" if inp.default is not None else "" + inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}") + + inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined." + + # format outputs as markdown list + outputs_parts = [] + for out in outputs: + if hasattr(out.type_hint, "__name__"): + type_str = out.type_hint.__name__ + elif out.type_hint is not None: + type_str = str(out.type_hint).replace("typing.", "") + else: + type_str = "Any" + desc = out.description or "No description provided" + outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}") + + outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs." + + trigger_inputs_section = "" + if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: + trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None]) + if trigger_inputs_list: + trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list) + trigger_inputs_section = f""" +### Conditional Execution + +This pipeline contains blocks that are selected at runtime based on inputs: +- **Trigger Inputs**: {trigger_inputs_str} +""" + + # generate tags based on pipeline characteristics + tags = ["modular-diffusers", "diffusers"] + + if hasattr(blocks, "model_name") and blocks.model_name: + tags.append(blocks.model_name) + + if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs: + triggers = blocks.trigger_inputs + if any(t in triggers for t in ["mask", "mask_image"]): + tags.append("inpainting") + if any(t in triggers for t in ["image", "image_latents"]): + tags.append("image-to-image") + if any(t in triggers for t in ["control_image", "controlnet_cond"]): + tags.append("controlnet") + if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]): + tags.append("text-to-image") + else: + tags.append("text-to-image") + + block_count = len(blocks.sub_blocks) + model_description = f"""This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework. 
+ +**Pipeline Type**: {blocks_class_name} + +**Description**: {description} + +This pipeline uses a {block_count}-block architecture that can be customized and extended.""" + + return { + "pipeline_name": pipeline_name, + "model_description": model_description, + "blocks_description": blocks_description, + "components_description": components_description, + "configs_section": configs_section, + "inputs_description": inputs_description, + "outputs_description": outputs_description, + "trigger_inputs_section": trigger_inputs_section, + "tags": tags, + } diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py deleted file mode 100644 index f7ee1dd3097b..000000000000 --- a/src/diffusers/modular_pipelines/node_utils.py +++ /dev/null @@ -1,661 +0,0 @@ -import json -import logging -import os -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import numpy as np -import PIL -import torch - -from ..configuration_utils import ConfigMixin -from ..image_processor import PipelineImageInput -from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks -from .modular_pipeline_utils import InputParam - - -logger = logging.getLogger(__name__) - -# YiYi Notes: this is actually for SDXL, put it here for now -SDXL_INPUTS_SCHEMA = { - "prompt": InputParam( - "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation" - ), - "prompt_2": InputParam( - "prompt_2", - type_hint=Union[str, List[str]], - description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2", - ), - "negative_prompt": InputParam( - "negative_prompt", - type_hint=Union[str, List[str]], - description="The prompt or prompts not to guide the image generation", - ), - "negative_prompt_2": InputParam( - "negative_prompt_2", - type_hint=Union[str, List[str]], - description="The negative prompt or prompts for text_encoder_2", - ), - "cross_attention_kwargs": InputParam( - "cross_attention_kwargs", - type_hint=Optional[dict], - description="Kwargs dictionary passed to the AttentionProcessor", - ), - "clip_skip": InputParam( - "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder" - ), - "image": InputParam( - "image", - type_hint=PipelineImageInput, - required=True, - description="The image(s) to modify for img2img or inpainting", - ), - "mask_image": InputParam( - "mask_image", - type_hint=PipelineImageInput, - required=True, - description="Mask image for inpainting, white pixels will be repainted", - ), - "generator": InputParam( - "generator", - type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], - description="Generator(s) for deterministic generation", - ), - "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"), - "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"), - "num_images_per_prompt": InputParam( - "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt" - ), - "num_inference_steps": InputParam( - "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps" - ), - "timesteps": InputParam( - "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process" - ), - "sigmas": InputParam( - "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process" - ), - 
"denoising_end": InputParam( - "denoising_end", - type_hint=Optional[float], - description="Fraction of denoising process to complete before termination", - ), - # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999 - "strength": InputParam( - "strength", type_hint=float, default=0.3, description="How much to transform the reference image" - ), - "denoising_start": InputParam( - "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process" - ), - "latents": InputParam( - "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation" - ), - "padding_mask_crop": InputParam( - "padding_mask_crop", - type_hint=Optional[Tuple[int, int]], - description="Size of margin in crop for image and mask", - ), - "original_size": InputParam( - "original_size", - type_hint=Optional[Tuple[int, int]], - description="Original size of the image for SDXL's micro-conditioning", - ), - "target_size": InputParam( - "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning" - ), - "negative_original_size": InputParam( - "negative_original_size", - type_hint=Optional[Tuple[int, int]], - description="Negative conditioning based on image resolution", - ), - "negative_target_size": InputParam( - "negative_target_size", - type_hint=Optional[Tuple[int, int]], - description="Negative conditioning based on target resolution", - ), - "crops_coords_top_left": InputParam( - "crops_coords_top_left", - type_hint=Tuple[int, int], - default=(0, 0), - description="Top-left coordinates for SDXL's micro-conditioning", - ), - "negative_crops_coords_top_left": InputParam( - "negative_crops_coords_top_left", - type_hint=Tuple[int, int], - default=(0, 0), - description="Negative conditioning crop coordinates", - ), - "aesthetic_score": InputParam( - "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image" - ), - "negative_aesthetic_score": InputParam( - "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score" - ), - "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"), - "output_type": InputParam( - "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)" - ), - "ip_adapter_image": InputParam( - "ip_adapter_image", - type_hint=PipelineImageInput, - required=True, - description="Image(s) to be used as IP adapter", - ), - "control_image": InputParam( - "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition" - ), - "control_guidance_start": InputParam( - "control_guidance_start", - type_hint=Union[float, List[float]], - default=0.0, - description="When ControlNet starts applying", - ), - "control_guidance_end": InputParam( - "control_guidance_end", - type_hint=Union[float, List[float]], - default=1.0, - description="When ControlNet stops applying", - ), - "controlnet_conditioning_scale": InputParam( - "controlnet_conditioning_scale", - type_hint=Union[float, List[float]], - default=1.0, - description="Scale factor for ControlNet outputs", - ), - "guess_mode": InputParam( - "guess_mode", - type_hint=bool, - default=False, - description="Enables ControlNet encoder to recognize input without prompts", - ), - "control_mode": InputParam( - "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" - ), -} - 
-SDXL_INTERMEDIATE_INPUTS_SCHEMA = { - "prompt_embeds": InputParam( - "prompt_embeds", - type_hint=torch.Tensor, - required=True, - description="Text embeddings used to guide image generation", - ), - "negative_prompt_embeds": InputParam( - "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings" - ), - "pooled_prompt_embeds": InputParam( - "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings" - ), - "negative_pooled_prompt_embeds": InputParam( - "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings" - ), - "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"), - "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - "preprocess_kwargs": InputParam( - "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" - ), - "latents": InputParam( - "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" - ), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam( - "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" - ), - "latent_timestep": InputParam( - "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" - ), - "image_latents": InputParam( - "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image" - ), - "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"), - "masked_image_latents": InputParam( - "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting" - ), - "add_time_ids": InputParam( - "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning" - ), - "negative_add_time_ids": InputParam( - "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids" - ), - "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"), - "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"), - "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"), - "ip_adapter_embeds": InputParam( - "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter" - ), - "negative_ip_adapter_embeds": InputParam( - "negative_ip_adapter_embeds", - type_hint=List[torch.Tensor], - description="Negative image embeddings for IP-Adapter", - ), - "images": InputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - required=True, - description="Generated images", - ), -} - -SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA} - - -DEFAULT_PARAM_MAPS = { - "prompt": { - "label": "Prompt", - "type": "string", - "default": "a bear sitting in a chair drinking a milkshake", - "display": "textarea", - }, - "negative_prompt": { - "label": "Negative Prompt", - "type": "string", - "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality", - "display": "textarea", - }, - "num_inference_steps": { - "label": "Steps", - "type": "int", - "default": 25, - "min": 1, - "max": 1000, - }, - "seed": { - 
"label": "Seed", - "type": "int", - "default": 0, - "min": 0, - "display": "random", - }, - "width": { - "label": "Width", - "type": "int", - "display": "text", - "default": 1024, - "min": 8, - "max": 8192, - "step": 8, - "group": "dimensions", - }, - "height": { - "label": "Height", - "type": "int", - "display": "text", - "default": 1024, - "min": 8, - "max": 8192, - "step": 8, - "group": "dimensions", - }, - "images": { - "label": "Images", - "type": "image", - "display": "output", - }, - "image": { - "label": "Image", - "type": "image", - "display": "input", - }, -} - -DEFAULT_TYPE_MAPS = { - "int": { - "type": "int", - "default": 0, - "min": 0, - }, - "float": { - "type": "float", - "default": 0.0, - "min": 0.0, - }, - "str": { - "type": "string", - "default": "", - }, - "bool": { - "type": "boolean", - "default": False, - }, - "image": { - "type": "image", - }, -} - -DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"] -DEFAULT_CATEGORY = "Modular Diffusers" -DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"] -DEFAULT_PARAMS_GROUPS_KEYS = { - "text_encoders": ["text_encoder", "tokenizer"], - "ip_adapter_embeds": ["ip_adapter_embeds"], - "prompt_embeddings": ["prompt_embeds"], -} - - -def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS): - """ - Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" -> - "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None - """ - if name is None: - return None - for group_name, group_keys in group_params_keys.items(): - for group_key in group_keys: - if group_key in name: - return group_name - return None - - -class ModularNode(ConfigMixin): - """ - A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper - around a ModularPipelineBlocks object. - - > [!WARNING] > This is an experimental feature and is likely to change in the future. - """ - - config_name = "node_config.json" - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: str, - trust_remote_code: Optional[bool] = None, - **kwargs, - ): - blocks = ModularPipelineBlocks.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs - ) - return cls(blocks, **kwargs) - - def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs): - self.blocks = blocks - - if label is None: - label = self.blocks.__class__.__name__ - # blocks param name -> mellon param name - self.name_mapping = {} - - input_params = {} - # pass or create a default param dict for each input - # e.g. for prompt, - # prompt = { - # "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers - # "label": "Prompt", - # "type": "string", - # "default": "a bear sitting in a chair drinking a milkshake", - # "display": "textarea"} - # if type is not specified, it'll be a "custom" param of its own type - # e.g. you can pass ModularNode(scheduler = {name :"scheduler"}) - # it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}} - # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. 
text_encoder= {name: {"text_encoders": "text_encoder"}} - inputs = self.blocks.inputs + self.blocks.intermediate_inputs - for inp in inputs: - param = kwargs.pop(inp.name, None) - if param: - # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...}) - input_params[inp.name] = param - mellon_name = param.pop("name", inp.name) - if mellon_name != inp.name: - self.name_mapping[inp.name] = mellon_name - continue - - if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name): - continue - - if inp.name in DEFAULT_PARAM_MAPS: - # first check if it's in the default param map, if so, directly use that - param = DEFAULT_PARAM_MAPS[inp.name].copy() - elif get_group_name(inp.name): - param = get_group_name(inp.name) - if inp.name not in self.name_mapping: - self.name_mapping[inp.name] = param - else: - # if not, check if it's in the SDXL input schema, if so, - # 1. use the type hint to determine the type - # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}} - if inp.type_hint is not None: - type_str = str(inp.type_hint).lower() - else: - inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None) - type_str = str(inp_spec.type_hint).lower() if inp_spec else "" - for type_key, type_param in DEFAULT_TYPE_MAPS.items(): - if type_key in type_str: - param = type_param.copy() - param["label"] = inp.name - param["display"] = "input" - break - else: - param = inp.name - # add the param dict to the inp_params dict - input_params[inp.name] = param - - component_params = {} - for comp in self.blocks.expected_components: - param = kwargs.pop(comp.name, None) - if param: - component_params[comp.name] = param - mellon_name = param.pop("name", comp.name) - if mellon_name != comp.name: - self.name_mapping[comp.name] = mellon_name - continue - - to_exclude = False - for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS: - if exclude_key in comp.name: - to_exclude = True - break - if to_exclude: - continue - - if get_group_name(comp.name): - param = get_group_name(comp.name) - if comp.name not in self.name_mapping: - self.name_mapping[comp.name] = param - elif comp.name in DEFAULT_MODEL_KEYS: - param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"} - else: - param = comp.name - # add the param dict to the model_params dict - component_params[comp.name] = param - - output_params = {} - if isinstance(self.blocks, SequentialPipelineBlocks): - last_block_name = list(self.blocks.sub_blocks.keys())[-1] - outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs - else: - outputs = self.blocks.intermediate_outputs - - for out in outputs: - param = kwargs.pop(out.name, None) - if param: - output_params[out.name] = param - mellon_name = param.pop("name", out.name) - if mellon_name != out.name: - self.name_mapping[out.name] = mellon_name - continue - - if out.name in DEFAULT_PARAM_MAPS: - param = DEFAULT_PARAM_MAPS[out.name].copy() - param["display"] = "output" - else: - group_name = get_group_name(out.name) - if group_name: - param = group_name - if out.name not in self.name_mapping: - self.name_mapping[out.name] = param - else: - param = out.name - # add the param dict to the outputs dict - output_params[out.name] = param - - if len(kwargs) > 0: - logger.warning(f"Unused kwargs: {kwargs}") - - register_dict = { - "category": category, - "label": label, - "input_params": input_params, - "component_params": component_params, - "output_params": output_params, - "name_mapping": 
self.name_mapping, - } - self.register_to_config(**register_dict) - - def setup(self, components_manager, collection=None): - self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection) - self._components_manager = components_manager - - @property - def mellon_config(self): - return self._convert_to_mellon_config() - - def _convert_to_mellon_config(self): - node = {} - node["label"] = self.config.label - node["category"] = self.config.category - - node_param = {} - for inp_name, inp_param in self.config.input_params.items(): - if inp_name in self.name_mapping: - mellon_name = self.name_mapping[inp_name] - else: - mellon_name = inp_name - if isinstance(inp_param, str): - param = { - "label": inp_param, - "type": inp_param, - "display": "input", - } - else: - param = inp_param - - if mellon_name not in node_param: - node_param[mellon_name] = param - else: - logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}") - - for comp_name, comp_param in self.config.component_params.items(): - if comp_name in self.name_mapping: - mellon_name = self.name_mapping[comp_name] - else: - mellon_name = comp_name - if isinstance(comp_param, str): - param = { - "label": comp_param, - "type": comp_param, - "display": "input", - } - else: - param = comp_param - - if mellon_name not in node_param: - node_param[mellon_name] = param - else: - logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}") - - for out_name, out_param in self.config.output_params.items(): - if out_name in self.name_mapping: - mellon_name = self.name_mapping[out_name] - else: - mellon_name = out_name - if isinstance(out_param, str): - param = { - "label": out_param, - "type": out_param, - "display": "output", - } - else: - param = out_param - - if mellon_name not in node_param: - node_param[mellon_name] = param - else: - logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}") - node["params"] = node_param - return node - - def save_mellon_config(self, file_path): - """ - Save the Mellon configuration to a JSON file. - - Args: - file_path (str or Path): Path where the JSON file will be saved - - Returns: - Path: Path to the saved config file - """ - file_path = Path(file_path) - - # Create directory if it doesn't exist - os.makedirs(file_path.parent, exist_ok=True) - - # Create a combined dictionary with module definition and name mapping - config = {"module": self.mellon_config, "name_mapping": self.name_mapping} - - # Save the config to file - with open(file_path, "w", encoding="utf-8") as f: - json.dump(config, f, indent=2) - - logger.info(f"Mellon config and name mapping saved to {file_path}") - - return file_path - - @classmethod - def load_mellon_config(cls, file_path): - """ - Load a Mellon configuration from a JSON file. 
- - Args: - file_path (str or Path): Path to the JSON file containing Mellon config - - Returns: - dict: The loaded combined configuration containing 'module' and 'name_mapping' - """ - file_path = Path(file_path) - - if not file_path.exists(): - raise FileNotFoundError(f"Config file not found: {file_path}") - - with open(file_path, "r", encoding="utf-8") as f: - config = json.load(f) - - logger.info(f"Mellon config loaded from {file_path}") - - return config - - def process_inputs(self, **kwargs): - params_components = {} - for comp_name, comp_param in self.config.component_params.items(): - logger.debug(f"component: {comp_name}") - mellon_comp_name = self.name_mapping.get(comp_name, comp_name) - if mellon_comp_name in kwargs: - if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]: - comp = kwargs[mellon_comp_name].pop(comp_name) - else: - comp = kwargs.pop(mellon_comp_name) - if comp: - params_components[comp_name] = self._components_manager.get_one(comp["model_id"]) - - params_run = {} - for inp_name, inp_param in self.config.input_params.items(): - logger.debug(f"input: {inp_name}") - mellon_inp_name = self.name_mapping.get(inp_name, inp_name) - if mellon_inp_name in kwargs: - if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]: - inp = kwargs[mellon_inp_name].pop(inp_name) - else: - inp = kwargs.pop(mellon_inp_name) - if inp is not None: - params_run[inp_name] = inp - - return_output_names = list(self.config.output_params.keys()) - - return params_components, params_run, return_output_names - - def execute(self, **kwargs): - params_components, params_run, return_output_names = self.process_inputs(**kwargs) - - self.pipeline.update_components(**params_components) - output = self.pipeline(**params_run, output=return_output_names) - return output diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py index ae4ec4799fbc..2b01a5b5a4b5 100644 --- a/src/diffusers/modular_pipelines/qwenimage/__init__.py +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -21,27 +21,27 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["encoders"] = ["QwenImageTextEncoderStep"] - _import_structure["modular_blocks"] = [ - "ALL_BLOCKS", + _import_structure["modular_blocks_qwenimage"] = [ "AUTO_BLOCKS", - "CONTROLNET_BLOCKS", - "EDIT_AUTO_BLOCKS", - "EDIT_BLOCKS", - "EDIT_INPAINT_BLOCKS", - "EDIT_PLUS_AUTO_BLOCKS", - "EDIT_PLUS_BLOCKS", - "IMAGE2IMAGE_BLOCKS", - "INPAINT_BLOCKS", - "TEXT2IMAGE_BLOCKS", "QwenImageAutoBlocks", + ] + _import_structure["modular_blocks_qwenimage_edit"] = [ + "EDIT_AUTO_BLOCKS", "QwenImageEditAutoBlocks", + ] + _import_structure["modular_blocks_qwenimage_edit_plus"] = [ + "EDIT_PLUS_AUTO_BLOCKS", "QwenImageEditPlusAutoBlocks", ] + _import_structure["modular_blocks_qwenimage_layered"] = [ + "LAYERED_AUTO_BLOCKS", + "QwenImageLayeredAutoBlocks", + ] _import_structure["modular_pipeline"] = [ "QwenImageEditModularPipeline", "QwenImageEditPlusModularPipeline", "QwenImageModularPipeline", + "QwenImageLayeredModularPipeline", ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -51,28 +51,26 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .encoders import ( - QwenImageTextEncoderStep, - ) - from .modular_blocks import ( - ALL_BLOCKS, + from .modular_blocks_qwenimage import ( AUTO_BLOCKS, - CONTROLNET_BLOCKS, - 
EDIT_AUTO_BLOCKS, - EDIT_BLOCKS, - EDIT_INPAINT_BLOCKS, - EDIT_PLUS_AUTO_BLOCKS, - EDIT_PLUS_BLOCKS, - IMAGE2IMAGE_BLOCKS, - INPAINT_BLOCKS, - TEXT2IMAGE_BLOCKS, QwenImageAutoBlocks, + ) + from .modular_blocks_qwenimage_edit import ( + EDIT_AUTO_BLOCKS, QwenImageEditAutoBlocks, + ) + from .modular_blocks_qwenimage_edit_plus import ( + EDIT_PLUS_AUTO_BLOCKS, QwenImageEditPlusAutoBlocks, ) + from .modular_blocks_qwenimage_layered import ( + LAYERED_AUTO_BLOCKS, + QwenImageLayeredAutoBlocks, + ) from .modular_pipeline import ( QwenImageEditModularPipeline, QwenImageEditPlusModularPipeline, + QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) else: diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index 0e470332c6f4..80a379da6be0 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -23,7 +23,7 @@ from ...utils.torch_utils import randn_tensor, unwrap_module from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift @@ -113,10 +113,45 @@ def get_timesteps(scheduler, num_inference_steps, strength): return timesteps, num_inference_steps - t_start -# Prepare Latents steps +# ==================== +# 1. PREPARE LATENTS +# ==================== +# auto_docstring class QwenImagePrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise for the generation process + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + model_name = "qwenimage" @property @@ -132,28 +167,20 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), - InputParam( - name="batch_size", - required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. 
Can be generated in input step.", - ), - InputParam( - name="dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs, can be generated in input step.", - ), + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), OutputParam( name="latents", type_hint=torch.Tensor, @@ -207,7 +234,150 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring +class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks): + """ + Prepare initial random noise (B, layers+1, C, H, W) for the generation process + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. 
+ + Outputs: + height (`int`): + if not set, updated to default value + width (`int`): + if not set, updated to default value + latents (`Tensor`): + The initial latents to use for the denoising process + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("latents"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("layers"), + InputParam.template("num_images_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="height", type_hint=int, description="if not set, updated to default value"), + OutputParam(name="width", type_hint=int, description="if not set, updated to default value"), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs( + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + ) + + device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + # we can update the height and width here since it's used to generate the initial + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) + + shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + block_state.latents = components.pachifier.pack_latents(block_state.latents) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks): + """ + Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, + prepare_latents. Both noise and image latents should alreadybe patchified. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + """ + model_name = "qwenimage" @property @@ -229,12 +399,7 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The initial random noised, can be generated in prepare latent step.", ), - InputParam( - name="image_latents", - required=True, - type_hint=torch.Tensor, - description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", - ), + InputParam.template("image_latents", note="Can be generated from vae encoder and updated in input step."), InputParam( name="timesteps", required=True, @@ -251,6 +416,11 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=torch.Tensor, description="The initial random noised used for inpainting denoising.", ), + OutputParam( + name="latents", + type_hint=torch.Tensor, + description="The scaled noisy latents to use for inpainting/image-to-image denoising.", + ), ] @staticmethod @@ -288,7 +458,29 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks): + """ + Step that creates mask latents from preprocessed mask_image by interpolating to latent space. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + mask (`Tensor`): + The mask to use for the inpainting process. 
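Taken together, the prepare-latents, timestep, and mask steps in this file can be assembled into an ordered before-denoise chain; a hand-rolled sketch (import paths, block names, and ordering are assumptions for illustration; the PR's real presets live in the modular_blocks_* modules, and the strength-aware timestep step appears just below):

    from collections import OrderedDict

    from diffusers.modular_pipelines import SequentialPipelineBlocks
    from diffusers.modular_pipelines.qwenimage.before_denoise import (
        QwenImageCreateMaskLatentsStep,
        QwenImagePrepareLatentsStep,
        QwenImagePrepareLatentsWithStrengthStep,
        QwenImageSetTimestepsWithStrengthStep,
    )

    # inpainting-style chain: random noise -> timesteps -> strength-scaled noisy latents -> mask latents
    blocks_dict = OrderedDict(
        prepare_latents=QwenImagePrepareLatentsStep(),
        set_timesteps=QwenImageSetTimestepsWithStrengthStep(),
        prepare_inpaint_latents=QwenImagePrepareLatentsWithStrengthStep(),
        create_mask_latents=QwenImageCreateMaskLatentsStep(),
    )
    before_denoise = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
    print(before_denoise)  # lists the sub-blocks with their inputs and outputs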
+ """ + model_name = "qwenimage" @property @@ -310,9 +502,9 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The processed mask to use for the inpainting process.", ), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="dtype", required=True), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("dtype"), ] @property @@ -351,15 +543,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -# Set Timesteps steps +# ==================== +# 2. SET TIMESTEPS +# ==================== +# auto_docstring class QwenImageSetTimestepsStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The initial random noised latents for the denoising process. Can be generated in prepare latents step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process + """ + model_name = "qwenimage" @property def description(self) -> str: - return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step." + return "Step that sets the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step." @property def expected_components(self) -> List[ComponentSpec]: @@ -370,13 +584,13 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), InputParam( name="latents", required=True, type_hint=torch.Tensor, - description="The latents to use for the denoising process, used to calculate the image sequence length.", + description="The initial random noised latents for the denoising process. Can be generated in prepare latents step.", ), ] @@ -420,12 +634,117 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring +class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks): + """ + Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents." 
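As a quick illustration of the custom mu calculation this layered step performs (its __call__ further below contains the actual implementation): the shift parameter is derived from the packed image sequence length relative to a 256-token base, and the default sigma schedule is a plain linspace. The helper below is a standalone sketch with dummy tensor shapes, not pipeline code.

import numpy as np
import torch


def layered_mu_and_sigmas(image_latents: torch.Tensor, num_inference_steps: int = 50):
    # image_latents is packed as (batch, seq_len, channels); seq_len drives the shift
    base_seqlen = 256 * 256 / 16 / 16  # 256 tokens for a 256x256 image after 16x packing
    mu = (image_latents.shape[1] / base_seqlen) ** 0.5

    # default sigmas used when the caller does not provide a custom schedule
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    return mu, sigmas


# example with a dummy packed latent of 1024 tokens -> mu = 2.0
mu, sigmas = layered_mu_and_sigmas(torch.zeros(1, 1024, 64))

Both values are then forwarded to retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu), as the step's __call__ does, after which scheduler.set_begin_index(0) resets the begin index.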
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process." + ), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # Layered-specific mu calculation + base_seqlen = 256 * 256 / 16 / 16 # = 256 + mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5 + + # Default sigmas if not provided + sigmas = ( + np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + if block_state.sigmas is None + else block_state.sigmas + ) + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare + latents step. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + latents (`Tensor`): + The latents to use for the denoising process. Can be generated in prepare latents step. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + + Outputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. + num_inference_steps (`int`): + The number of denoising steps to perform at inference time. Updated based on strength. + """ + model_name = "qwenimage" @property def description(self) -> str: - return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step." + return "Step that sets the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step." @property def expected_components(self) -> List[ComponentSpec]: @@ -436,15 +755,15 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), InputParam( - name="latents", + "latents", required=True, type_hint=torch.Tensor, - description="The latents to use for the denoising process, used to calculate the image sequence length.", + description="The latents to use for the denoising process. 
Can be generated in prepare latents step.", ), - InputParam(name="strength", default=0.9), + InputParam.template("strength", default=0.9), ] @property @@ -453,7 +772,12 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam( name="timesteps", type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + description="The timesteps to use for the denoising process.", + ), + OutputParam( + name="num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time. Updated based on strength.", ), ] @@ -493,12 +817,36 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -# other inputs for denoiser +# ==================== +# 3. OTHER INPUTS FOR DENOISER +# ==================== ## RoPE inputs for denoiser +# auto_docstring class QwenImageRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the images latents, used for RoPE calculation + """ + model_name = "qwenimage" @property @@ -510,11 +858,11 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -522,9 +870,195 @@ def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( name="img_shapes", + kwargs_type="denoiser_input_fields", type_hint=List[List[Tuple[int, int, int]]], description="The shapes of the images latents, used for RoPE calculation", ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + ] + ] * block_state.batch_size + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after + prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`int`): + The height of the reference image. 
Can be generated in input step. + image_width (`int`): + The width of the reference image. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the images latents, used for RoPE calculation + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=int, + description="The height of the reference image. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=int, + description="The width of the reference image. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=List[List[Tuple[int, int, int]]], + description="The shapes of the images latents, used for RoPE calculation", + ), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # for edit, image size can be different from the target size (height/width) + block_state.img_shapes = [ + [ + ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ), + ( + 1, + block_state.image_height // components.vae_scale_factor // 2, + block_state.image_width // components.vae_scale_factor // 2, + ), + ] + ] * block_state.batch_size + + self.set_block_state(state, block_state) + + return components, state + + +# auto_docstring +class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus. + Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images. Should be placed + after prepare_latents step. + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_height (`List`): + The heights of the reference images. Can be generated in input step. + image_width (`List`): + The widths of the reference images. Can be generated in input step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. 
+ + Outputs: + img_shapes (`List`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n" + "Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n" + "Should be placed after prepare_latents step." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("batch_size"), + InputParam( + name="image_height", + required=True, + type_hint=List[int], + description="The heights of the reference images. Can be generated in input step.", + ), + InputParam( + name="image_width", + required=True, + type_hint=List[int], + description="The widths of the reference images. Can be generated in input step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="img_shapes", + kwargs_type="denoiser_input_fields", + type_hint=List[List[Tuple[int, int, int]]], + description="The shapes of the image latents, used for RoPE calculation", + ), OutputParam( name="txt_seq_lens", kwargs_type="denoiser_input_fields", @@ -542,15 +1076,19 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + vae_scale_factor = components.vae_scale_factor + + # Edit Plus: image_height and image_width are lists block_state.img_shapes = [ [ - ( - 1, - block_state.height // components.vae_scale_factor // 2, - block_state.width // components.vae_scale_factor // 2, - ) + (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2), + *[ + (1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2) + for img_height, img_width in zip(block_state.image_height, block_state.image_width) + ], ] ] * block_state.batch_size + block_state.txt_seq_lens = ( block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None ) @@ -565,23 +1103,54 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): - model_name = "qwenimage" +# auto_docstring +class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks): + """ + Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step + + Inputs: + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. 
+ negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + img_shapes (`List`): + The shapes of the image latents, used for RoPE calculation + txt_seq_lens (`List`): + The sequence lengths of the prompt embeds, used for RoPE calculation + negative_txt_seq_lens (`List`): + The sequence lengths of the negative prompt embeds, used for RoPE calculation + additional_t_cond (`Tensor`): + The additional t cond, used for RoPE calculation + """ + + model_name = "qwenimage-layered" @property def description(self) -> str: - return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step" + return ( + "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step" + ) @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="batch_size", required=True), - InputParam(name="image_height", required=True), - InputParam(name="image_width", required=True), - InputParam(name="height", required=True), - InputParam(name="width", required=True), - InputParam(name="prompt_embeds_mask"), - InputParam(name="negative_prompt_embeds_mask"), + InputParam.template("batch_size"), + InputParam.template("layers"), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds_mask"), ] @property @@ -590,42 +1159,46 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam( name="img_shapes", type_hint=List[List[Tuple[int, int, int]]], - description="The shapes of the images latents, used for RoPE calculation", + kwargs_type="denoiser_input_fields", + description="The shapes of the image latents, used for RoPE calculation", ), OutputParam( name="txt_seq_lens", - kwargs_type="denoiser_input_fields", type_hint=List[int], + kwargs_type="denoiser_input_fields", description="The sequence lengths of the prompt embeds, used for RoPE calculation", ), OutputParam( name="negative_txt_seq_lens", - kwargs_type="denoiser_input_fields", type_hint=List[int], + kwargs_type="denoiser_input_fields", description="The sequence lengths of the negative prompt embeds, used for RoPE calculation", ), + OutputParam( + name="additional_t_cond", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="The additional t cond, used for RoPE calculation", + ), ] - def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # for edit, image size can be different from the target size (height/width) + device = components._execution_device - block_state.img_shapes = [ - [ - ( - 1, - block_state.height // components.vae_scale_factor // 2, - block_state.width // components.vae_scale_factor // 2, - ), - ( - 1, - block_state.image_height // components.vae_scale_factor // 2, - block_state.image_width // components.vae_scale_factor // 2, - ), - ] - ] * block_state.batch_size + # All shapes are the same for Layered + shape = ( + 1, + block_state.height // components.vae_scale_factor // 2, + block_state.width // components.vae_scale_factor // 2, + ) + + # layers+1 output shapes + 1 condition shape (all same) + block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size + # txt_seq_lens 
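# (each mask is (batch, seq_len); summing over the token axis gives per-prompt lengths,
# and the value stays None when the text encoder step produced no mask)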
block_state.txt_seq_lens = ( block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None ) @@ -635,13 +1208,41 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - else None ) - self.set_block_state(state, block_state) + block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long) + self.set_block_state(state, block_state) return components, state ## ControlNet inputs for denoiser + + +# auto_docstring class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): + """ + step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step. + + Components: + controlnet (`QwenImageControlNetModel`) + + Inputs: + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + + Outputs: + controlnet_keep (`List`): + The controlnet keep values + """ + model_name = "qwenimage" @property @@ -657,12 +1258,17 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), - InputParam("control_image_latents", required=True), + InputParam.template("control_guidance_start"), + InputParam.template("control_guidance_end"), + InputParam.template("controlnet_conditioning_scale"), InputParam( - "timesteps", + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam( + name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py index 26417162deee..1adbf6bdd355 100644 --- a/src/diffusers/modular_pipelines/qwenimage/decoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
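For reference, a self-contained sketch of the RoPE-related inputs the layered step above assembles; the height, width, scale factor, and layer counts here are made-up example values, while the real block reads them from the pipeline state.

import torch

# hypothetical example values; the block reads these from block_state and components
height, width, vae_scale_factor = 1024, 1024, 8
layers, batch_size = 4, 1

# every latent frame shares one shape: layers + 1 output frames plus 1 condition frame
shape = (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)
img_shapes = [[shape] * (layers + 2)] * batch_size

# text sequence lengths come from the prompt-embeds attention mask
prompt_embeds_mask = torch.ones(batch_size, 64, dtype=torch.long)
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()  # e.g. [64]

# layered RoPE additionally expects an all-zero time-conditioning id per batch item
additional_t_cond = torch.tensor([0] * batch_size, dtype=torch.long)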
-from typing import List, Union +from typing import Any, Dict, List -import numpy as np -import PIL import torch from ...configuration_utils import FrozenDict @@ -24,23 +22,46 @@ from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier logger = logging.get_logger(__name__) -class QwenImageDecoderStep(ModularPipelineBlocks): +# after denoising loop (unpack latents) + + +# auto_docstring +class QwenImageAfterDenoiseStep(ModularPipelineBlocks): + """ + Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, + channels, 1, height, width) + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + latents (`Tensor`): + The latents to decode, can be generated in the denoise step. + + Outputs: + latents (`Tensor`): + The denoisedlatents unpacked to B, C, 1, H, W + """ + model_name = "qwenimage" @property def description(self) -> str: - return "Step that decodes the latents to images" + return "Step that unpack the latents from 3D tensor (batch_size, sequence_length, channels) into 5D tensor (batch_size, channels, 1, height, width)" @property def expected_components(self) -> List[ComponentSpec]: components = [ - ComponentSpec("vae", AutoencoderKLQwenImage), ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), ] @@ -49,35 +70,170 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="height", required=True), - InputParam(name="width", required=True), + InputParam.template("height", required=True), + InputParam.template("width", required=True), InputParam( name="latents", required=True, type_hint=torch.Tensor, - description="The latents to decode, can be generated in the denoise step", + description="The latents to decode, can be generated in the denoise step.", ), ] @property - def intermediate_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", - ) + name="latents", type_hint=torch.Tensor, description="The denoisedlatents unpacked to B, C, 1, H, W" + ), ] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular vae_scale_factor = components.vae_scale_factor block_state.latents = components.pachifier.unpack_latents( block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising. 
+ + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + + Outputs: + latents (`Tensor`): + Denoised latents. (unpacked to B, C, layers+1, H, W) + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step.", + ), + InputParam.template("height", required=True), + InputParam.template("width", required=True), + InputParam.template("layers"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("latents", note="unpacked to B, C, layers+1, H, W"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W) + block_state.latents = components.pachifier.unpack_latents( + block_state.latents, + block_state.height, + block_state.width, + block_state.layers, + components.vae_scale_factor, + ) + + self.set_block_state(state, block_state) + return components, state + + +# decode step + + +# auto_docstring +class QwenImageDecoderStep(ModularPipelineBlocks): + """ + Step that decodes the latents to images + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Step that decodes the latents to images" + + @property + def expected_components(self) -> List[ComponentSpec]: + components = [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ] + + return components + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images", note="tensor output of the vae decoder.")] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular + if block_state.latents.ndim == 4: + block_state.latents = block_state.latents.unsqueeze(dim=1) + elif block_state.latents.ndim != 5: + raise ValueError( + f"expect latents to be a 4D or 5D tensor but got: {block_state.latents.shape}. Please make sure the latents are unpacked before decode step." 
+ ) block_state.latents = block_state.latents.to(components.vae.dtype) latents_mean = ( @@ -95,7 +251,126 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring +class QwenImageLayeredDecoderStep(ModularPipelineBlocks): + """ + Decode unpacked latents (B, C, layers+1, H, W) into layer images. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Decode unpacked latents (B, C, layers+1, H, W) into layer images." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLQwenImage), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise step.", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("images"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents + + # 1. VAE normalization + latents = latents.to(components.vae.dtype) + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + + # 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W) + b, c, f, h, w = latents.shape + # 3. Remove first frame (composite), keep layers frames + latents = latents[:, :, 1:] + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) + + # 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W) + image = components.vae.decode(latents, return_dict=False)[0] + image = image.squeeze(2) + + # 5. Postprocess - returns flat list of B*layers images + image = components.image_processor.postprocess(image, output_type=block_state.output_type) + + # 6. Chunk into list per batch item + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + block_state.images = images + + self.set_block_state(state, block_state) + return components, state + + +# postprocess the decoded images + + +# auto_docstring class QwenImageProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. 
+ """ + model_name = "qwenimage" @property @@ -116,15 +391,19 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("images", required=True, description="the generated image from decoders step"), InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", ), + InputParam.template("output_type"), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images")] + @staticmethod def check_inputs(output_type): if output_type not in ["pil", "np", "pt"]: @@ -145,7 +424,28 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks): + """ + postprocess the generated image, optional apply the mask overally to the original image.. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + images (`Tensor`): + the generated image tensor from decoders step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" @property @@ -166,16 +466,24 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("images", required=True, description="the generated image from decoders step"), InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", + name="images", + required=True, + type_hint=torch.Tensor, + description="the generated image tensor from decoders step", + ), + InputParam.template("output_type"), + InputParam( + name="mask_overlay_kwargs", + type_hint=Dict[str, Any], + description="The kwargs for the postprocess step to apply the mask overlay. generated in InpaintProcessImagesInputStep.", ), - InputParam("mask_overlay_kwargs"), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam.template("images")] + @staticmethod def check_inputs(output_type, mask_overlay_kwargs): if output_type not in ["pil", "np", "pt"]: diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index 49acd2dc0295..8579c9843a89 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import List, Tuple import torch @@ -28,7 +29,12 @@ logger = logging.get_logger(__name__) +# ==================== +# 1. LOOP STEPS (run at each denoising step) +# ==================== + +# loop step:before denoiser class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks): model_name = "qwenimage" @@ -44,7 +50,7 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam( - "latents", + name="latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. 
Can be generated in prepare_latent step.", @@ -60,7 +66,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks): - model_name = "qwenimage" + model_name = "qwenimage-edit" @property def description(self) -> str: @@ -74,17 +80,12 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam( - "latents", + name="latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.", - ), + InputParam.template("image_latents"), ] @torch.no_grad() @@ -128,29 +129,12 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", ), + InputParam.template("controlnet_conditioning_scale", note="updated in prepare_controlnet_inputs step."), InputParam( - "controlnet_conditioning_scale", - type_hint=float, - description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", - ), - InputParam( - "controlnet_keep", + name="controlnet_keep", required=True, type_hint=List[float], - description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description=( - "All conditional model inputs for the denoiser. " - "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens." - ), + description="The controlnet keep values. Can be generated in prepare_controlnet_inputs step.", ), ] @@ -176,7 +160,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState img_shapes=block_state.img_shapes, encoder_hidden_states=block_state.prompt_embeds, encoder_hidden_states_mask=block_state.prompt_embeds_mask, - txt_seq_lens=block_state.txt_seq_lens, return_dict=False, ) @@ -185,6 +168,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState return components, block_state +# loop step:denoiser class QwenImageLoopDenoiser(ModularPipelineBlocks): model_name = "qwenimage" @@ -211,28 +195,13 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The latents to use for the denoising process. Can be generated in prepare_latents step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), InputParam( "img_shapes", required=True, type_hint=List[Tuple[int, int]], - description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.", + description="The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.", ), ] @@ -247,12 +216,15 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) guider_state = components.guider.prepare_inputs(guider_inputs) @@ -264,7 +236,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latent_model_input, timestep=block_state.timestep / 1000, - img_shapes=block_state.img_shapes, attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, @@ -284,7 +255,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState class QwenImageEditLoopDenoiser(ModularPipelineBlocks): - model_name = "qwenimage" + model_name = "qwenimage-edit" @property def description(self) -> str: @@ -309,23 +280,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The latents to use for the denoising process. Can be generated in prepare_latents step.", - ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. 
prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.template("attention_kwargs"), + InputParam.template("denoiser_input_fields"), InputParam( "img_shapes", required=True, @@ -345,12 +301,15 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState getattr(block_state, "prompt_embeds_mask", None), getattr(block_state, "negative_prompt_embeds_mask", None), ), - "txt_seq_lens": ( - getattr(block_state, "txt_seq_lens", None), - getattr(block_state, "negative_txt_seq_lens", None), - ), } + transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys()) + additional_cond_kwargs = {} + for field_name, field_value in block_state.denoiser_input_fields.items(): + if field_name in transformer_args and field_name not in guider_inputs: + additional_cond_kwargs[field_name] = field_value + block_state.additional_cond_kwargs.update(additional_cond_kwargs) + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) guider_state = components.guider.prepare_inputs(guider_inputs) @@ -362,7 +321,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latent_model_input, timestep=block_state.timestep / 1000, - img_shapes=block_state.img_shapes, attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, @@ -384,6 +342,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState return components, block_state +# loop step:after denoiser class QwenImageLoopAfterDenoiser(ModularPipelineBlocks): model_name = "qwenimage" @@ -404,7 +363,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."), + OutputParam.template("latents"), ] @torch.no_grad() @@ -445,24 +404,19 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.", ), - InputParam( - "image_latents", - required=True, - type_hint=torch.Tensor, - description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.", - ), + InputParam.template("image_latents"), InputParam( "initial_noise", required=True, type_hint=torch.Tensor, description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.", ), - InputParam( - "timesteps", - required=True, - type_hint=torch.Tensor, - description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", - ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("latents"), ] @torch.no_grad() @@ -481,6 +435,9 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState return components, block_state +# ==================== +# 2. DENOISE LOOP WRAPPER: define the denoising loop logic +# ==================== class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): model_name = "qwenimage" @@ -501,17 +458,12 @@ def loop_expected_components(self) -> List[ComponentSpec]: def loop_inputs(self) -> List[InputParam]: return [ InputParam( - "timesteps", + name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.", ), - InputParam( - "num_inference_steps", - required=True, - type_hint=int, - description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", - ), + InputParam.template("num_inference_steps", required=True), ] @torch.no_grad() @@ -537,8 +489,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -# composing the denoising loops +# ==================== +# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps +# ==================== + + +# Qwen Image (text2image, image2image) + + +# auto_docstring class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2image and image2image tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ QwenImageLoopBeforeDenoiser, QwenImageLoopDenoiser, @@ -549,8 +543,8 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper): @property def description(self) -> str: return ( - "Denoise step that iteratively denoise the latents. \n" - "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "Denoise step that iteratively denoise the latents.\n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method\n" "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" " - `QwenImageLoopBeforeDenoiser`\n" " - `QwenImageLoopDenoiser`\n" @@ -559,8 +553,49 @@ def description(self) -> str: ) -# composing the inpainting denoising loops +# Qwen Image (inpainting) +# auto_docstring class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. 
+ num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, QwenImageLoopDenoiser, @@ -583,8 +618,49 @@ def description(self) -> str: ) -# composing the controlnet denoising loops +# Qwen Image (text2image, image2image) with controlnet +# auto_docstring class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports text2img/img2img tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, QwenImageLoopBeforeDenoiserControlNet, @@ -607,8 +683,56 @@ def description(self) -> str: ) -# composing the controlnet denoising loops +# Qwen Image (inpainting) with controlnet +# auto_docstring class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. 
+ Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageLoopBeforeDenoiser` + - `QwenImageLoopBeforeDenoiserControlNet` + - `QwenImageLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks with controlnet for QwenImage. + + Components: + guider (`ClassifierFreeGuidance`) controlnet (`QwenImageControlNetModel`) transformer + (`QwenImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + control_image_latents (`Tensor`): + The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.) + controlnet_keep (`List`): + The controlnet keep values. Can be generated in prepare_controlnet_inputs step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" block_classes = [ QwenImageLoopBeforeDenoiser, QwenImageLoopBeforeDenoiserControlNet, @@ -639,8 +763,44 @@ def description(self) -> str: ) -# composing the denoising loops +# Qwen Image Edit (image2image) +# auto_docstring class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. 
+ img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" block_classes = [ QwenImageEditLoopBeforeDenoiser, QwenImageEditLoopDenoiser, @@ -661,7 +821,49 @@ def description(self) -> str: ) +# Qwen Image Edit (inpainting) +# auto_docstring class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + - `QwenImageLoopAfterDenoiserInpaint` + This block supports inpainting tasks for QwenImage Edit. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + mask (`Tensor`): + The mask to use for the inpainting process. Can be generated in inpaint prepare latents step. + initial_noise (`Tensor`): + The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" block_classes = [ QwenImageEditLoopBeforeDenoiser, QwenImageEditLoopDenoiser, @@ -682,3 +884,61 @@ def description(self) -> str: " - `QwenImageLoopAfterDenoiserInpaint`\n" "This block supports inpainting tasks for QwenImage Edit." ) + + +# Qwen Image Layered (image2image) +# auto_docstring +class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper): + """ + Denoise step that iteratively denoise the latents. + Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method At each iteration, it runs blocks + defined in `sub_blocks` sequencially: + - `QwenImageEditLoopBeforeDenoiser` + - `QwenImageEditLoopDenoiser` + - `QwenImageLoopAfterDenoiser` + This block supports QwenImage Layered. + + Components: + guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + num_inference_steps (`int`): + The number of denoising steps. + latents (`Tensor`): + The initial latents to use for the denoising process. Can be generated in prepare_latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. 
+ **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + img_shapes (`List`): + The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageEditLoopBeforeDenoiser, + QwenImageEditLoopDenoiser, + QwenImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n" + " - `QwenImageEditLoopBeforeDenoiser`\n" + " - `QwenImageEditLoopDenoiser`\n" + " - `QwenImageLoopAfterDenoiser`\n" + "This block supports QwenImage Layered." + ) diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 3b56981e5290..5e1821cca5c0 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Text and VAE encoder blocks for QwenImage pipelines. +""" + from typing import Dict, List, Optional, Union import PIL @@ -26,8 +30,19 @@ from ...utils import logging from ...utils.torch_utils import unwrap_module from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import QwenImageModularPipeline +from .prompt_templates import ( + QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, + QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, + QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, + QWENIMAGE_EDIT_PROMPT_TEMPLATE, + QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, + QWENIMAGE_LAYERED_CAPTION_PROMPT_CN, + QWENIMAGE_LAYERED_CAPTION_PROMPT_EN, + QWENIMAGE_PROMPT_TEMPLATE, + QWENIMAGE_PROMPT_TEMPLATE_START_IDX, +) logger = logging.get_logger(__name__) @@ -45,8 +60,8 @@ def get_qwen_prompt_embeds( text_encoder, tokenizer, prompt: Union[str, List[str]] = None, - prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", - prompt_template_encode_start_idx: int = 34, + prompt_template_encode: str = QWENIMAGE_PROMPT_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_PROMPT_TEMPLATE_START_IDX, tokenizer_max_length: int = 1024, device: Optional[torch.device] = None, ): @@ -86,8 +101,8 @@ def get_qwen_prompt_embeds_edit( processor, prompt: Union[str, List[str]] = None, image: Optional[torch.Tensor] = None, - prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", - prompt_template_encode_start_idx: int = 64, + prompt_template_encode: str = QWENIMAGE_EDIT_PROMPT_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX, device: Optional[torch.device] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -133,9 +148,9 @@ def get_qwen_prompt_embeds_edit_plus( processor, prompt: Union[str, List[str]] = None, image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None, - prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", - img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>", - prompt_template_encode_start_idx: int = 64, + prompt_template_encode: str = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE, + img_template_encode: str = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE, + prompt_template_encode_start_idx: int = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX, device: Optional[torch.device] = None, ): prompt = [prompt] if isinstance(prompt, str) else prompt @@ -241,33 +256,50 @@ def encode_vae_image( return image_latents -class QwenImageEditResizeDynamicStep(ModularPipelineBlocks): - model_name = "qwenimage" +# ==================== +# 1. RESIZE +# ==================== +# In QwenImage pipelines, resize is a separate step because the resized image is used in VL encoding and vae encoder blocks: +# +# image (PIL.Image.Image) +# │ +# ▼ +# resized_image ([PIL.Image.Image]) +# │ +# ├──► text_encoder ──► prompt_embeds, prompt_embeds_mask +# │ (VL encoding needs the resized image for vision-language fusion) +# │ +# └──► image_processor ──► processed_image (torch.Tensor, pixel space) +# │ +# ▼ +# vae_encoder ──► image_latents (torch.Tensor, latent space) +# +# In most of our other pipelines, resizing is done as part of the image preprocessing step. +# ==================== - def __init__(self, input_name: str = "image", output_name: str = "resized_image"): - """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio. - This block resizes an input image tensor and exposes the resized result under configurable input and output - names. Use this when you need to wire the resize step to different image fields (e.g., "image", - "control_image") +# auto_docstring +class QwenImageEditResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to target area while maintaining the aspect ratio. - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". - output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". 
- """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self._image_input_name = input_name - self._resized_image_output_name = output_name - super().__init__() + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + + Outputs: + resized_image (`List`): + The resized images + """ + + model_name = "qwenimage-edit" @property def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio." + return "Image Resize step that resize the image to target area while maintaining the aspect ratio." @property def expected_components(self) -> List[ComponentSpec]: @@ -282,9 +314,89 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: + return [InputParam.template("image")] + + @property + def intermediate_outputs(self) -> List[OutputParam]: return [ + OutputParam( + name="resized_image", + type_hint=List[PIL.Image.Image], + description="The resized images", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = block_state.image + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if is_valid_image(images): + images = [images] + + image_width, image_height = images[0].size + calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + + resized_images = [ + components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) + for image in images + ] + + block_state.resized_image = resized_images + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageLayeredResizeStep(ModularPipelineBlocks): + """ + Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while + maintaining the aspect ratio. + + Components: + image_resize_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + + Outputs: + resized_image (`List`): + The resized images + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to a target area (defined by the resolution parameter from user) while maintaining the aspect ratio." 
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("image"), InputParam( - name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" + name="resolution", + default=640, + type_hint=int, + description="The target area to resize the image to, can be 1024 or 640", ), ] @@ -292,15 +404,24 @@ def inputs(self) -> List[InputParam]: def intermediate_outputs(self) -> List[OutputParam]: return [ OutputParam( - name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" - ), + name="resized_image", + type_hint=List[PIL.Image.Image], + description="The resized images", + ) ] + @staticmethod + def check_inputs(resolution: int): + if resolution not in [1024, 640]: + raise ValueError(f"Resolution must be 1024 or 640 but is {resolution}") + @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - images = getattr(block_state, self._image_input_name) + self.check_inputs(resolution=block_state.resolution) + + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") @@ -309,59 +430,79 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): images = [images] image_width, image_height = images[0].size - calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + target_area = block_state.resolution * block_state.resolution + calculated_width, calculated_height, _ = calculate_dimensions(target_area, image_width / image_height) resized_images = [ components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width) for image in images ] - setattr(block_state, self._resized_image_output_name, resized_images) + block_state.resized_image = resized_images self.set_block_state(state, block_state) return components, state -class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep): - model_name = "qwenimage" +# auto_docstring +class QwenImageEditPlusResizeStep(ModularPipelineBlocks): + """ + Resize images for QwenImage Edit Plus pipeline. + Produces two outputs: resized_image (1024x1024) for VAE encoding, resized_cond_image (384x384) for VL text + encoding. Each image is resized independently based on its own aspect ratio. - def __init__( - self, - input_name: str = "image", - output_name: str = "resized_image", - vae_image_output_name: str = "vae_image", - ): - """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio. + Components: + image_resize_processor (`VaeImageProcessor`) - This block resizes an input image or a list input images and exposes the resized result under configurable - input and output names. Use this when you need to wire the resize step to different image fields (e.g., - "image", "control_image") + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". 
- output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". - vae_image_output_name (str, optional): Name of the image field - to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus - processes the input image(s) differently for the VL and the VAE. - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError( - f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" - ) - self.condition_image_size = 384 * 384 - self._image_input_name = input_name - self._resized_image_output_name = output_name - self._vae_image_output_name = vae_image_output_name - super().__init__() + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + """ + + model_name = "qwenimage-edit-plus" + + @property + def description(self) -> str: + return ( + "Resize images for QwenImage Edit Plus pipeline.\n" + "Produces two outputs: resized_image (1024x1024) for VAE encoding, " + "resized_cond_image (384x384) for VL text encoding.\n" + "Each image is resized independently based on its own aspect ratio." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_resize_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + # image + return [InputParam.template("image")] @property def intermediate_outputs(self) -> List[OutputParam]: - return super().intermediate_outputs + [ + return [ + OutputParam( + name="resized_image", + type_hint=List[PIL.Image.Image], + description="Images resized to 1024x1024 target area for VAE encoding", + ), OutputParam( - name=self._vae_image_output_name, + name="resized_cond_image", type_hint=List[PIL.Image.Image], - description="The images to be processed which will be further used by the VAE encoder.", + description="Images resized to 384x384 target area for VL text encoding", ), ] @@ -369,41 +510,194 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - images = getattr(block_state, self._image_input_name) + images = block_state.image if not is_valid_image_imagelist(images): raise ValueError(f"Images must be image or list of images but are {type(images)}") - if ( - not isinstance(images, torch.Tensor) - and isinstance(images, PIL.Image.Image) - and not isinstance(images, list) - ): + if is_valid_image(images): images = [images] - # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s - condition_images = [] - vae_images = [] - for img in images: - image_width, image_height = img.size - condition_width, condition_height, _ = calculate_dimensions( - self.condition_image_size, image_width / image_height + # Resize each image independently based on its own aspect ratio + resized_images = [] + resized_cond_images = [] + for image in images: + image_width, image_height = image.size + + # For VAE encoder (1024x1024 target area) + vae_width, vae_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height) + resized_images.append(components.image_resize_processor.resize(image, height=vae_height, width=vae_width)) + + # For VL text encoder (384x384 
target area) + vl_width, vl_height, _ = calculate_dimensions(384 * 384, image_width / image_height) + resized_cond_images.append( + components.image_resize_processor.resize(image, height=vl_height, width=vl_width) ) - condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width)) - vae_images.append(img) - setattr(block_state, self._resized_image_output_name, condition_images) - setattr(block_state, self._vae_image_output_name, vae_images) + block_state.resized_image = resized_images + block_state.resized_cond_image = resized_cond_images self.set_block_state(state, block_state) return components, state +# ==================== +# 2. GET IMAGE PROMPT +# ==================== + + +# auto_docstring +class QwenImageLayeredGetImagePromptStep(ModularPipelineBlocks): + """ + Auto-caption step that generates a text prompt from the input image if none is provided. + Uses the VL model (text_encoder) to generate a description of the image. If prompt is already provided, this step + passes through unchanged. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + resized_image (`Image`): + The image to generate caption from, should be resized use the resize step + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + + Outputs: + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + """ + + model_name = "qwenimage-layered" + + def __init__(self): + self.image_caption_prompt_en = QWENIMAGE_LAYERED_CAPTION_PROMPT_EN + self.image_caption_prompt_cn = QWENIMAGE_LAYERED_CAPTION_PROMPT_CN + super().__init__() + + @property + def description(self) -> str: + return ( + "Auto-caption step that generates a text prompt from the input image if none is provided.\n" + "Uses the VL model (text_encoder) to generate a description of the image.\n" + "If prompt is already provided, this step passes through unchanged." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template( + "prompt", required=False + ), # it is not required for qwenimage-layered, unlike other pipelines + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The image to generate caption from, should be resized use the resize step", + ), + InputParam( + name="use_en_prompt", + default=False, + type_hint=bool, + description="Whether to use English prompt template", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="prompt", + type_hint=str, + description="The prompt or prompts to guide image generation. 
If not provided, updated using image caption", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + # If prompt is empty or None, generate caption from image + if block_state.prompt is None or block_state.prompt == "" or block_state.prompt == " ": + if block_state.use_en_prompt: + caption_prompt = self.image_caption_prompt_en + else: + caption_prompt = self.image_caption_prompt_cn + + model_inputs = components.processor( + text=caption_prompt, + images=block_state.resized_image, + padding=True, + return_tensors="pt", + ).to(device) + + generated_ids = components.text_encoder.generate(**model_inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = components.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + + block_state.prompt = output_text.strip() + + self.set_block_state(state, block_state) + return components, state + + +# ==================== +# 3. TEXT ENCODER +# ==================== + + +# auto_docstring class QwenImageTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that generates text embeddings to guide the image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + model_name = "qwenimage" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_PROMPT_TEMPLATE_START_IDX + self.tokenizer_max_length = 1024 + super().__init__() + @property def description(self) -> str: - return "Text Encoder step that generate text_embeddings to guide the image generation" + return "Text Encoder step that generates text embeddings to guide the image generation." 
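
The template/start-index pair that `QwenImageTextEncoderStep` now reads from `prompt_templates.py` works the same way as the inlined defaults removed elsewhere in this diff: the user prompt is formatted into a fixed chat template, and after encoding, the first `prompt_template_encode_start_idx` hidden states (the fixed template preamble) are sliced off so only the prompt-specific embeddings condition the denoiser. A small illustration using the default values shown in this diff:

```python
# Values copied from the defaults removed elsewhere in this diff
# (QWENIMAGE_PROMPT_TEMPLATE / QWENIMAGE_PROMPT_TEMPLATE_START_IDX).
QWENIMAGE_PROMPT_TEMPLATE = (
    "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
    "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)
QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34  # number of template tokens dropped from the hidden states

def build_encoder_text(prompt: str) -> str:
    # The text encoder sees the full templated string; get_qwen_prompt_embeds then keeps only
    # hidden states from index QWENIMAGE_PROMPT_TEMPLATE_START_IDX onward.
    return QWENIMAGE_PROMPT_TEMPLATE.format(prompt)

print(build_encoder_text("a corgi wearing sunglasses"))
```
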
@property def expected_components(self) -> List[ComponentSpec]: @@ -418,54 +712,21 @@ def expected_components(self) -> List[ComponentSpec]: ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec( - name="prompt_template_encode", - default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", - ), - ConfigSpec(name="prompt_template_encode_start_idx", default=34), - ConfigSpec(name="tokenizer_max_length", default=1024), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), - InputParam( - name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024 - ), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=1024), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -494,9 +755,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.tokenizer, prompt=block_state.prompt, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, - tokenizer_max_length=components.config.tokenizer_max_length, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, device=device, ) @@ -511,9 +772,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.tokenizer, prompt=negative_prompt, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, - tokenizer_max_length=components.config.tokenizer_max_length, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, + tokenizer_max_length=self.tokenizer_max_length, device=device, ) block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[ @@ -527,12 +788,45 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageEditTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that 
processes both prompt and image together to generate text embeddings for guiding image + generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_image (`Image`): + The image prompt to encode, should be resized using resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + model_name = "qwenimage" + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PROMPT_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX + super().__init__() + @property def description(self) -> str: - return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation" + return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation." @property def expected_components(self) -> List[ComponentSpec]: @@ -547,25 +841,15 @@ def expected_components(self) -> List[ComponentSpec]: ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [ - ConfigSpec( - name="prompt_template_encode", - default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n", - ), - ConfigSpec(name="prompt_template_encode_start_idx", default=64), - ] - @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.template("prompt"), + InputParam.template("negative_prompt"), InputParam( name="resized_image", required=True, - type_hint=torch.Tensor, + type_hint=PIL.Image.Image, description="The image prompt to encode, should be resized using resize step", ), ] @@ -573,30 +857,10 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - name="prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The prompt embeddings", - ), - OutputParam( - name="prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The encoder attention mask", - ), - OutputParam( - name="negative_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings", - ), - OutputParam( - name="negative_prompt_embeds_mask", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="The negative prompt embeddings mask", - ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), ] @staticmethod @@ -624,8 +888,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.processor, prompt=block_state.prompt, image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -638,8 +902,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.processor, prompt=negative_prompt, image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + prompt_template_encode=self.prompt_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -647,23 +911,98 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state -class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep): - model_name = "qwenimage" +# auto_docstring +class QwenImageEditPlusTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together to generate text + embeddings for guiding image generation. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor (`Qwen2VLProcessor`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. 
+ negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + resized_cond_image (`Tensor`): + The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using + resize step + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit-plus" + + def __init__(self): + self.prompt_template_encode = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE + self.img_template_encode = QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE + self.prompt_template_encode_start_idx = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX + super().__init__() + + @property + def description(self) -> str: + return ( + "Text Encoder step for QwenImage Edit Plus that processes prompt and multiple images together " + "to generate text embeddings for guiding image generation." + ) @property - def expected_configs(self) -> List[ConfigSpec]: + def expected_components(self) -> List[ComponentSpec]: return [ - ConfigSpec( - name="prompt_template_encode", - default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration), + ComponentSpec("processor", Qwen2VLProcessor), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", ), - ConfigSpec( - name="img_template_encode", - default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>", + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam( + name="resized_cond_image", + required=True, + type_hint=torch.Tensor, + description="The image(s) to encode, can be a single image or list of images, should be resized to 384x384 using resize step", ), - ConfigSpec(name="prompt_template_encode_start_idx", default=64), ] + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) @@ -676,10 +1015,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.processor, prompt=block_state.prompt, - image=block_state.resized_image, - 
prompt_template_encode=components.config.prompt_template_encode, - img_template_encode=components.config.img_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + image=block_state.resized_cond_image, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) @@ -692,10 +1031,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): components.text_encoder, components.processor, prompt=negative_prompt, - image=block_state.resized_image, - prompt_template_encode=components.config.prompt_template_encode, - img_template_encode=components.config.img_template_encode, - prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + image=block_state.resized_cond_image, + prompt_template_encode=self.prompt_template_encode, + img_template_encode=self.img_template_encode, + prompt_template_encode_start_idx=self.prompt_template_encode_start_idx, device=device, ) ) @@ -704,12 +1043,46 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# ==================== +# 4. IMAGE PREPROCESS +# ==================== + + +# auto_docstring class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be + resized to the given height and width. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + model_name = "qwenimage" @property def description(self) -> str: - return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep." + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images will be resized to the given height and width." 
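
The `check_inputs(height, width, vae_scale_factor)` guard used by this block (and by `QwenImageProcessImagesInputStep` further down) is not shown in this hunk. A plausible reconstruction is a simple divisibility check; the exact required multiple (assumed `vae_scale_factor * 2` here) is a guess, not the confirmed implementation:

```python
def check_dimensions(height: int | None, width: int | None, vae_scale_factor: int = 16) -> None:
    # Hypothetical sketch of the divisibility validation; the required multiple is an assumption.
    multiple = vae_scale_factor * 2
    if height is not None and height % multiple != 0:
        raise ValueError(f"height must be divisible by {multiple}, got {height}")
    if width is not None and width % multiple != 0:
        raise ValueError(f"width must be divisible by {multiple}, got {width}")

check_dimensions(1024, 1024)    # passes
# check_dimensions(1000, 1024)  # would raise ValueError
```
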
@property def expected_components(self) -> List[ComponentSpec]: @@ -725,19 +1098,26 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("resized_image"), - InputParam("image"), - InputParam("height"), - InputParam("width"), - InputParam("padding_mask_crop"), + InputParam.template("mask_image"), + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("padding_mask_crop"), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="processed_image"), - OutputParam(name="processed_mask_image"), + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), OutputParam( name="mask_overlay_kwargs", type_hint=Dict, @@ -757,23 +1137,107 @@ def check_inputs(height, width, vae_scale_factor): def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - if block_state.resized_image is None and block_state.image is None: - raise ValueError("resized_image and image cannot be None at the same time") + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width - if block_state.resized_image is None: - image = block_state.image - self.check_inputs( - height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( + components.image_mask_processor.preprocess( + image=block_state.image, + mask=block_state.mask_image, + height=height, + width=width, + padding_mask_crop=block_state.padding_mask_crop, ) - height = block_state.height or components.default_height - width = block_state.width or components.default_width - else: - width, height = block_state.resized_image[0].size - image = block_state.resized_image + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be + resized first. + + Components: + image_mask_processor (`InpaintProcessor`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + resized_image (`Image`): + The resized image. should be generated using a resize step + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images should be resized first." 
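
For orientation, the `InpaintProcessor.preprocess` contract these two inpaint blocks depend on boils down to: resize image and mask to the working resolution, normalize the image to `[-1, 1]`, and binarize the mask. A rough stand-in is sketched below; the real processor additionally handles `padding_mask_crop` and returns `mask_overlay_kwargs` for the overlay postprocess step:

```python
import numpy as np
import torch
from PIL import Image

def sketch_inpaint_preprocess(image: Image.Image, mask: Image.Image, height: int, width: int):
    # Rough stand-in for InpaintProcessor.preprocess: image -> [-1, 1] float tensor,
    # mask -> {0, 1} float tensor, both resized to (height, width).
    image = image.convert("RGB").resize((width, height))
    mask = mask.convert("L").resize((width, height))

    image_t = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 127.5 - 1.0
    mask_t = (torch.from_numpy(np.array(mask)).float() / 255.0 > 0.5).float()

    return image_t.unsqueeze(0), mask_t[None, None]  # (1, 3, H, W), (1, 1, H, W)
```
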
+ + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_mask_processor", + InpaintProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("mask_image"), + InputParam( + name="resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The resized image. should be generated using a resize step", + ), + InputParam.template("padding_mask_crop"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="processed_image", type_hint=torch.Tensor, description="The processed image"), + OutputParam( + name="processed_mask_image", + type_hint=torch.Tensor, + description="The processed mask image", + ), + OutputParam( + name="mask_overlay_kwargs", + type_hint=Dict, + description="The kwargs for the postprocess step to apply the mask overlay", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = ( components.image_mask_processor.preprocess( - image=image, + image=block_state.resized_image, mask=block_state.mask_image, height=height, width=width, @@ -785,12 +1249,32 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +# auto_docstring class QwenImageProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. will resize the image to the given height and width. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage" @property def description(self) -> str: - return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep." + return "Image Preprocess step. will resize the image to the given height and width." 
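
All of the blocks in this file follow the same `ModularPipelineBlocks` contract: declare `inputs` / `intermediate_outputs` (plus `expected_components` when models or processors are needed), then read and write state via `get_block_state` / `set_block_state` inside `__call__`. A minimal hypothetical block sketched under that contract; the import paths are inferred from the relative imports in this diff, and the block itself is an illustration rather than part of the library:

```python
from typing import List

import torch

# Paths mirror the relative imports used in this diff
# (..modular_pipeline / ..modular_pipeline_utils under diffusers.modular_pipelines).
from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam


class ScaleImageLatentsStep(ModularPipelineBlocks):
    """Hypothetical block: rescales image_latents by a user-supplied factor."""

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "Scales image_latents by latents_scale."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="image_latents", required=True, type_hint=torch.Tensor,
                       description="Latents produced by a VAE encoder step"),
            InputParam(name="latents_scale", default=1.0, type_hint=float,
                       description="Multiplier applied to the latents"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="image_latents", type_hint=torch.Tensor, description="Scaled latents")]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState):
        block_state = self.get_block_state(state)
        block_state.image_latents = block_state.image_latents * block_state.latents_scale
        self.set_block_state(state, block_state)
        return components, state
```
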
@property def expected_components(self) -> List[ComponentSpec]: @@ -805,12 +1289,20 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + return [ + InputParam.template("image"), + InputParam.template("height"), + InputParam.template("width"), + ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="processed_image"), + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) ] @staticmethod @@ -825,22 +1317,85 @@ def check_inputs(height, width, vae_scale_factor): def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - if block_state.resized_image is None and block_state.image is None: - raise ValueError("resized_image and image cannot be None at the same time") + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + block_state.processed_image = components.image_processor.preprocess( + image=block_state.image, + height=height, + width=width, + ) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images needs to be resized first. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`List`): + The resized image. should be generated using a resize step - if block_state.resized_image is None: - image = block_state.image - self.check_inputs( - height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + Outputs: + processed_image (`Tensor`): + The processed image + """ + + model_name = "qwenimage-edit" + + @property + def description(self) -> str: + return "Image Preprocess step. Images needs to be resized first." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + name="resized_image", + required=True, + type_hint=List[PIL.Image.Image], + description="The resized image. 
should be generated using a resize step", + ), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", ) - height = block_state.height or components.default_height - width = block_state.width or components.default_width - else: - width, height = block_state.resized_image[0].size - image = block_state.resized_image + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + width, height = block_state.resized_image[0].size block_state.processed_image = components.image_processor.preprocess( - image=image, + image=block_state.resized_image, height=height, width=width, ) @@ -849,105 +1404,169 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state -class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): +# auto_docstring +class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks): + """ + Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of + processed images. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + resized_image (`List`): + The resized image. should be generated using a resize step + + Outputs: + processed_image (`Tensor`): + The processed image + """ + model_name = "qwenimage-edit-plus" - vae_image_size = 1024 * 1024 @property def description(self) -> str: - return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing." + return "Image Preprocess step. Images can be resized first. If a list of images is provided, will return a list of processed images." + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] @property def inputs(self) -> List[InputParam]: - return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")] + return [ + InputParam( + name="resized_image", + required=True, + type_hint=List[PIL.Image.Image], + description="The resized image. 
should be generated using a resize step", + ) + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="processed_image", + type_hint=torch.Tensor, + description="The processed image", + ) + ] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) - if block_state.vae_image is None and block_state.image is None: - raise ValueError("`vae_image` and `image` cannot be None at the same time") + image = block_state.resized_image - if block_state.vae_image is None: - image = block_state.image - self.check_inputs( - height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor - ) - height = block_state.height or components.default_height - width = block_state.width or components.default_width - block_state.processed_image = components.image_processor.preprocess( - image=image, height=height, width=width - ) - else: - width, height = block_state.vae_image[0].size - image = block_state.vae_image + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] - block_state.processed_image = components.image_processor.preprocess( - image=image, height=height, width=width + processed_images = [] + for img in image: + img_width, img_height = img.size + processed_images.append( + components.image_processor.preprocess(image=img, height=img_height, width=img_width) ) + if is_image_list: + block_state.processed_image = processed_images + else: + block_state.processed_image = processed_images[0] + self.set_block_state(state, block_state) return components, state -class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks): +# ==================== +# 5. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts processed_image into latent representations image_latents. + Handles both single images and lists of images with varied resolutions. + + Components: + vae (`AutoencoderKLQwenImage`) + + Inputs: + processed_image (`Tensor`): + The image tensor to encode + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. + """ + model_name = "qwenimage" def __init__( self, - input_name: str = "processed_image", - output_name: str = "image_latents", + input: Optional[InputParam] = None, + output: Optional[OutputParam] = None, ): """Initialize a VAE encoder step for converting images to latent representations. - Both the input and output names are configurable so this block can be configured to process to different image - inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents"). + Handles both single images and lists of images. When input is a list, outputs a list of latents. When input is + a single tensor, outputs a single latent tensor. Args: - input_name (str, optional): Name of the input image tensor. Defaults to "processed_image". - Examples: "processed_image" or "processed_control_image" - output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents". - Examples: "image_latents" or "control_image_latents" + input (InputParam, optional): Input parameter for the processed image. Defaults to "processed_image". + output (OutputParam, optional): Output parameter for the image latents. Defaults to "image_latents". 
+ """ + if input is None: + input = InputParam( + name="processed_image", required=True, type_hint=torch.Tensor, description="The image tensor to encode" + ) - Examples: - # Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep() + if output is None: + output = OutputParam.template("image_latents") - # Custom input/output names for control image QwenImageVaeEncoderDynamicStep( - input_name="processed_control_image", output_name="control_image_latents" - ) - """ - self._image_input_name = input_name - self._image_latents_output_name = output_name + if not isinstance(input, InputParam): + raise ValueError(f"input must be InputParam but is {type(input)}") + if not isinstance(output, OutputParam): + raise ValueError(f"output must be OutputParam but is {type(output)}") + + self._input = input + self._output = output + self._image_input_name = input.name + self._image_latents_output_name = output.name super().__init__() @property def description(self) -> str: - return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + return ( + f"VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" + "Handles both single images and lists of images with varied resolutions." + ) @property def expected_components(self) -> List[ComponentSpec]: - components = [ - ComponentSpec("vae", AutoencoderKLQwenImage), - ] - return components + return [ComponentSpec("vae", AutoencoderKLQwenImage)] @property def inputs(self) -> List[InputParam]: - inputs = [ - InputParam(self._image_input_name, required=True), - InputParam("generator"), + return [ + self._input, # default is "processed_image" + InputParam.template("generator"), ] - return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam( - self._image_latents_output_name, - type_hint=torch.Tensor, - description="The latents representing the reference image", - ) - ] + return [self._output] # default is "image_latents" @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -957,16 +1576,26 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - dtype = components.vae.dtype image = getattr(block_state, self._image_input_name) + is_image_list = isinstance(image, list) + if not is_image_list: + image = [image] + + # Handle both single image and list of images + image_latents = [] + for img in image: + image_latents.append( + encode_vae_image( + image=img, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + ) + if not is_image_list: + image_latents = image_latents[0] - # Encode image into latents - image_latents = encode_vae_image( - image=image, - vae=components.vae, - generator=block_state.generator, - device=device, - dtype=dtype, - latent_channels=components.num_channels_latents, - ) setattr(block_state, self._image_latents_output_name, image_latents) self.set_block_state(state, block_state) @@ -974,7 +1603,30 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): + """ + VAE Encoder step that converts `control_image` into latent representations control_image_latents. 
+ + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + model_name = "qwenimage" @property @@ -998,10 +1650,10 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam("control_image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("generator"), + InputParam.template("control_image"), + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("generator"), ] return inputs @@ -1083,3 +1735,52 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - self.set_block_state(state, block_state) return components, state + + +# ==================== +# 6. PERMUTE LATENTS +# ==================== + + +# auto_docstring +class QwenImageLayeredPermuteLatentsStep(ModularPipelineBlocks): + """ + Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing. + + Inputs: + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_latents (`Tensor`): + The latent representation of the input image. (permuted from [B, C, 1, H, W] to [B, 1, C, H, W]) + """ + + model_name = "qwenimage-layered" + + @property + def description(self) -> str: + return "Permute image latents from (B, C, 1, H, W) to (B, 1, C, H, W) for Layered packing." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam.template("image_latents"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam.template("image_latents", note="permuted from [B, C, 1, H, W] to [B, 1, C, H, W]"), + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Permute: (B, C, 1, H, W) -> (B, 1, C, H, W) + latents = block_state.image_latents + block_state.image_latents = latents.permute(0, 2, 1, 3, 4) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 2b229c040b89..818bbca5ed0a 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
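For illustration only (not part of this diff), a minimal sketch of how the refactored QwenImageVaeEncoderStep above might be wired to a control image instead of the default "processed_image" -> "image_latents" mapping; the import paths and description strings are assumptions based on the file layout in this change:

import torch
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam  # assumed path
from diffusers.modular_pipelines.qwenimage.encoders import QwenImageVaeEncoderStep  # assumed path

# Reuse the generic VAE encoder block for a control image by overriding the
# default input/output parameter specs.
control_vae_encoder = QwenImageVaeEncoderStep(
    input=InputParam(
        name="processed_control_image",
        required=True,
        type_hint=torch.Tensor,
        description="The control image tensor to encode",
    ),
    output=OutputParam(
        name="control_image_latents",
        type_hint=torch.Tensor,
        description="The latents representing the control image",
    ),
)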
-from typing import List, Tuple +from typing import List, Optional, Tuple import torch from ...models import QwenImageMultiControlNetModel from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier +from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier def repeat_tensor_to_batch_size( @@ -109,7 +109,44 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: in return height, width +# auto_docstring class QwenImageTextInputsStep(ModularPipelineBlocks): + """ + Text input processing step that standardizes text embeddings for the pipeline. + This step: + 1. Determines `batch_size` and `dtype` based on `prompt_embeds` + 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt) + + This block should be placed after all encoder steps to process the text embeddings before they are used in + subsequent pipeline steps. + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. 
(batch-expanded) + """ + model_name = "qwenimage" @property @@ -129,26 +166,22 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"), - InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"), - InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"), - InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"), + InputParam.template("num_images_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask"), ] @property - def intermediate_outputs(self) -> List[str]: + def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - "batch_size", - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", - ), - OutputParam( - "dtype", - type_hint=torch.dtype, - description="Data type of model tensor inputs (determined by `prompt_embeds`)", - ), + OutputParam(name="batch_size", type_hint=int, description="The batch size of the prompt embeddings"), + OutputParam(name="dtype", type_hint=torch.dtype, description="The data type of the prompt embeddings"), + OutputParam.template("prompt_embeds", note="batch-expanded"), + OutputParam.template("prompt_embeds_mask", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds", note="batch-expanded"), + OutputParam.template("negative_prompt_embeds_mask", note="batch-expanded"), ] @staticmethod @@ -221,45 +254,289 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state -class QwenImageInputsDynamicStep(ModularPipelineBlocks): +# auto_docstring +class QwenImageAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step that: + 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
(patchified and + batch-expanded) + """ + model_name = "qwenimage" def __init__( self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, ): - """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + # by default, process `image_latents` + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
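For illustration only (not part of this diff), a minimal sketch of configuring QwenImageAdditionalInputsStep with an extra batch input, mirroring the inpaint usage mentioned above; the import paths and the description string are assumptions:

import torch
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam  # assumed path
from diffusers.modular_pipelines.qwenimage.inputs import QwenImageAdditionalInputsStep  # assumed path

# image_latent_inputs defaults to [InputParam.template("image_latents")] when omitted;
# additional_batch_inputs are only batch-expanded, not patchified.
inpaint_inputs = QwenImageAdditionalInputsStep(
    additional_batch_inputs=[
        InputParam(
            name="processed_mask_image",
            type_hint=torch.Tensor,
            description="The processed mask image",  # hypothetical description text
        )
    ],
)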
+ + return summary_section + inputs_info + placement_section + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), + ] + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), + ] + + # `height`/`width` are not new outputs, but they will be updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified and batch-expanded)", + ) + ) - This step handles multiple common tasks to prepare inputs for the denoising step: - 1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size - 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) - This is a dynamic block that allows you to configure which inputs to process. + return outputs - Args: - image_latent_inputs (List[str], optional): Names of image latent tensors to process. - These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or - list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"] - additional_batch_inputs (List[str], optional): - Names of additional conditional input tensors to expand batch size. These tensors will only have their - batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. - Defaults to []. Examples: ["processed_mask_image"] + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) - Examples: - # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep() + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue - # Configure to process multiple image latent inputs - QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"]) + # 1. 
Calculate height/width from latents and update if not provided + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify + image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue - # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, ) - """ + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# auto_docstring +class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step for Edit Plus that: + 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch + 2. For additional batch inputs: Expands batch dimensions to match final batch size + Height/width defaults to last image in the list. + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
(patchified, + concatenated, and batch-expanded) + """ + + model_name = "qwenimage-edit-plus" + + def __init__( + self, + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + if not isinstance(image_latent_inputs, list): - image_latent_inputs = [image_latent_inputs] + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + if not isinstance(additional_batch_inputs, list): - additional_batch_inputs = [additional_batch_inputs] + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) self._image_latent_inputs = image_latent_inputs self._additional_batch_inputs = additional_batch_inputs @@ -267,79 +544,340 @@ def __init__( @property def description(self) -> str: - # Functionality section summary_section = ( - "Input processing step that:\n" - " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n" - " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + "Input processing step for Edit Plus that:\n" + " 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size\n" + " Height/width defaults to last image in the list." ) - # Inputs info inputs_info = "" if self._image_latent_inputs or self._additional_batch_inputs: inputs_info = "\n\nConfigured inputs:" if self._image_latent_inputs: - inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" if self._additional_batch_inputs: - inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" - # Placement guidance placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
return summary_section + inputs_info + placement_section + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ] + @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + InputParam.template("height"), + InputParam.template("width"), ] - # Add image latent inputs - for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) - - # Add additional batch inputs - for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + # default is `image_latents` + inputs += self._image_latent_inputs + self._additional_batch_inputs return inputs @property def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam(name="image_height", type_hint=int, description="The height of the image latents"), - OutputParam(name="image_width", type_hint=int, description="The width of the image latents"), + outputs = [ + OutputParam( + name="image_height", + type_hint=List[int], + description="The image heights calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=List[int], + description="The image widths calculated from the image latents dimension", + ), ] + # `height`/`width` are updated if any image latent inputs are provided + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # image latent inputs are modified in place (patchified, concatenated, and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified, concatenated, and batch-expanded)", + ) + ) + + # additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + is_list = isinstance(image_latent_tensor, list) + if not is_list: + image_latent_tensor = [image_latent_tensor] + + image_heights = [] + image_widths = [] + packed_image_latent_tensors = [] + + for i, img_latent_tensor in enumerate(image_latent_tensor): + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor) + image_heights.append(height) + image_widths.append(width) + + # 2. Patchify + img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor) + + # 3. 
Expand batch size + img_latent_tensor = repeat_tensor_to_batch_size( + input_name=f"{image_latent_input_name}[{i}]", + input_tensor=img_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + packed_image_latent_tensors.append(img_latent_tensor) + + # Concatenate all packed latents along dim=1 + packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1) + + # Output lists of heights/widths + block_state.image_height = image_heights + block_state.image_width = image_widths + + # Default height/width from last image + block_state.height = block_state.height or image_heights[-1] + block_state.width = block_state.width or image_widths[-1] + + setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + + # Process additional batch inputs (only batch expansion) + for input_param in self._additional_batch_inputs: + input_name = input_param.name + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + +# same as QwenImageAdditionalInputsStep, but with layered pachifier. + + +# auto_docstring +class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks): + """ + Input processing step for Layered that: + 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch + size + 2. For additional batch inputs: Expands batch dimensions to match final batch size + + Configured inputs: + - Image latent inputs: ['image_latents'] + + This block should be placed after the encoder steps and the text input step. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. 
(patchified + with layered pachifier and batch-expanded) + """ + + model_name = "qwenimage-layered" + + def __init__( + self, + image_latent_inputs: Optional[List[InputParam]] = None, + additional_batch_inputs: Optional[List[InputParam]] = None, + ): + if image_latent_inputs is None: + image_latent_inputs = [InputParam.template("image_latents")] + if additional_batch_inputs is None: + additional_batch_inputs = [] + + if not isinstance(image_latent_inputs, list): + raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}") + else: + for input_param in image_latent_inputs: + if not isinstance(input_param, InputParam): + raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}") + + if not isinstance(additional_batch_inputs, list): + raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}") + else: + for input_param in additional_batch_inputs: + if not isinstance(input_param, InputParam): + raise ValueError( + f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}" + ) + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + summary_section = ( + "Input processing step for Layered that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}" + + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." 
+ + return summary_section + inputs_info + placement_section + @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"), + ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size"), + ] + # default is `image_latents` + + inputs += self._image_latent_inputs + self._additional_batch_inputs + + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + outputs = [ + OutputParam( + name="image_height", + type_hint=int, + description="The image height calculated from the image latents dimension", + ), + OutputParam( + name="image_width", + type_hint=int, + description="The image width calculated from the image latents dimension", + ), ] + if len(self._image_latent_inputs) > 0: + outputs.append( + OutputParam(name="height", type_hint=int, description="if not provided, updated to image height") + ) + outputs.append( + OutputParam(name="width", type_hint=int, description="if not provided, updated to image width") + ) + + # Add outputs for image latent inputs (patchified with layered pachifier and batch-expanded) + for input_param in self._image_latent_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (patchified with layered pachifier and batch-expanded)", + ) + ) + + # Add outputs for additional batch inputs (batch-expanded only) + for input_param in self._additional_batch_inputs: + outputs.append( + OutputParam( + name=input_param.name, + type_hint=input_param.type_hint, + description=input_param.description + " (batch-expanded)", + ) + ) + + return outputs + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - # Process image latent inputs (height/width calculation, patchify, and batch expansion) - for image_latent_input_name in self._image_latent_inputs: + # Process image latent inputs + for input_param in self._image_latent_inputs: + image_latent_input_name = input_param.name image_latent_tensor = getattr(block_state, image_latent_input_name) if image_latent_tensor is None: continue - # 1. Calculate height/width from latents - height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) - block_state.height = block_state.height or height - block_state.width = block_state.width or width + # 1. Calculate height/width from latents and update if not provided + # Layered latents are (B, layers, C, H, W) + height = image_latent_tensor.shape[3] * components.vae_scale_factor + width = image_latent_tensor.shape[4] * components.vae_scale_factor + block_state.height = height + block_state.width = width if not hasattr(block_state, "image_height"): block_state.image_height = height if not hasattr(block_state, "image_width"): block_state.image_width = width - # 2. Patchify the image latent tensor + # 2. Patchify with layered pachifier image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor) # 3. 
Expand batch size @@ -353,12 +891,12 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - setattr(block_state, image_latent_input_name, image_latent_tensor) # Process additional batch inputs (only batch expansion) - for input_name in self._additional_batch_inputs: + for input_param in self._additional_batch_inputs: + input_name = input_param.name input_tensor = getattr(block_state, input_name) if input_tensor is None: continue - # Only expand batch size input_tensor = repeat_tensor_to_batch_size( input_name=input_name, input_tensor=input_tensor, @@ -372,7 +910,34 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +# auto_docstring class QwenImageControlNetInputsStep(ModularPipelineBlocks): + """ + prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps. + + Inputs: + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + batch_size (`int`, *optional*, defaults to 1): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + + Outputs: + control_image_latents (`Tensor`): + The control image latents (patchified and batch-expanded). + height (`int`): + if not provided, updated to control image height + width (`int`): + if not provided, updated to control image width + """ + model_name = "qwenimage" @property @@ -382,11 +947,28 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="control_image_latents", required=True), - InputParam(name="batch_size", required=True), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="height"), - InputParam(name="width"), + InputParam( + name="control_image_latents", + required=True, + type_hint=torch.Tensor, + description="The control image latents to use for the denoising process. Can be generated in controlnet vae encoder step.", + ), + InputParam.template("batch_size"), + InputParam.template("num_images_per_prompt"), + InputParam.template("height"), + InputParam.template("width"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="control_image_latents", + type_hint=torch.Tensor, + description="The control image latents (patchified and batch-expanded).", + ), + OutputParam(name="height", type_hint=int, description="if not provided, updated to control image height"), + OutputParam(name="width", type_hint=int, description="if not provided, updated to control image width"), ] @torch.no_grad() diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py deleted file mode 100644 index 419894164389..000000000000 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ /dev/null @@ -1,1035 +0,0 @@ -# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks -from ..modular_pipeline_utils import InsertableDict -from .before_denoise import ( - QwenImageControlNetBeforeDenoiserStep, - QwenImageCreateMaskLatentsStep, - QwenImageEditRoPEInputsStep, - QwenImagePrepareLatentsStep, - QwenImagePrepareLatentsWithStrengthStep, - QwenImageRoPEInputsStep, - QwenImageSetTimestepsStep, - QwenImageSetTimestepsWithStrengthStep, -) -from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep -from .denoise import ( - QwenImageControlNetDenoiseStep, - QwenImageDenoiseStep, - QwenImageEditDenoiseStep, - QwenImageEditInpaintDenoiseStep, - QwenImageInpaintControlNetDenoiseStep, - QwenImageInpaintDenoiseStep, - QwenImageLoopBeforeDenoiserControlNet, -) -from .encoders import ( - QwenImageControlNetVaeEncoderStep, - QwenImageEditPlusProcessImagesInputStep, - QwenImageEditPlusResizeDynamicStep, - QwenImageEditPlusTextEncoderStep, - QwenImageEditResizeDynamicStep, - QwenImageEditTextEncoderStep, - QwenImageInpaintProcessImagesInputStep, - QwenImageProcessImagesInputStep, - QwenImageTextEncoderStep, - QwenImageVaeEncoderDynamicStep, -) -from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep - - -logger = logging.get_logger(__name__) - -# 1. QwenImage - -## 1.1 QwenImage/text2image - -#### QwenImage/decode -#### (standard decode step works for most tasks except for inpaint) -QwenImageDecodeBlocks = InsertableDict( - [ - ("decode", QwenImageDecoderStep()), - ("postprocess", QwenImageProcessImagesOutputStep()), - ] -) - - -class QwenImageDecodeStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageDecodeBlocks.values() - block_names = QwenImageDecodeBlocks.keys() - - @property - def description(self): - return "Decode step that decodes the latents to images and postprocess the generated image." - - -#### QwenImage/text2image presets -TEXT2IMAGE_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("input", QwenImageTextInputsStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ("denoise", QwenImageDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -## 1.2 QwenImage/inpaint - -#### QwenImage/inpaint vae encoder -QwenImageInpaintVaeEncoderBlocks = InsertableDict( - [ - ( - "preprocess", - QwenImageInpaintProcessImagesInputStep, - ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents - ] -) - - -class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintVaeEncoderBlocks.values() - block_names = QwenImageInpaintVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return ( - "This step is used for processing image and mask inputs for inpainting tasks. 
It:\n" - " - Resizes the image to the target size, based on `height` and `width`.\n" - " - Processes and updates `image` and `mask_image`.\n" - " - Creates `image_latents`." - ) - - -#### QwenImage/inpaint inputs -QwenImageInpaintInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ( - "additional_inputs", - QwenImageInputsDynamicStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] - ), - ), - ] -) - - -class QwenImageInpaintInputStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintInputBlocks.values() - block_names = QwenImageInpaintInputBlocks.keys() - - @property - def description(self): - return "Input step that prepares the inputs for the inpainting denoising step. It:\n" - " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n" - " - update height/width based `image_latents`, patchify `image_latents`." - - -# QwenImage/inpaint prepare latents -QwenImageInpaintPrepareLatentsBlocks = InsertableDict( - [ - ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()), - ("create_mask_latents", QwenImageCreateMaskLatentsStep()), - ] -) - - -class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintPrepareLatentsBlocks.values() - block_names = QwenImageInpaintPrepareLatentsBlocks.keys() - - @property - def description(self) -> str: - return ( - "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n" - " - Add noise to the image latents to create the latents input for the denoiser.\n" - " - Create the pachified latents `mask` based on the processedmask image.\n" - ) - - -#### QwenImage/inpaint decode -QwenImageInpaintDecodeBlocks = InsertableDict( - [ - ("decode", QwenImageDecoderStep()), - ("postprocess", QwenImageInpaintProcessImagesOutputStep()), - ] -) - - -class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintDecodeBlocks.values() - block_names = QwenImageInpaintDecodeBlocks.keys() - - @property - def description(self): - return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image." 
- - -#### QwenImage/inpaint presets -INPAINT_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("vae_encoder", QwenImageInpaintVaeEncoderStep()), - ("input", QwenImageInpaintInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ("denoise", QwenImageInpaintDenoiseStep()), - ("decode", QwenImageInpaintDecodeStep()), - ] -) - - -## 1.3 QwenImage/img2img - -#### QwenImage/img2img vae encoder -QwenImageImg2ImgVaeEncoderBlocks = InsertableDict( - [ - ("preprocess", QwenImageProcessImagesInputStep()), - ("encode", QwenImageVaeEncoderDynamicStep()), - ] -) - - -class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - - block_classes = QwenImageImg2ImgVaeEncoderBlocks.values() - block_names = QwenImageImg2ImgVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return "Vae encoder step that preprocess andencode the image inputs into their latent representations." - - -#### QwenImage/img2img inputs -QwenImageImg2ImgInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])), - ] -) - - -class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageImg2ImgInputBlocks.values() - block_names = QwenImageImg2ImgInputBlocks.keys() - - @property - def description(self): - return "Input step that prepares the inputs for the img2img denoising step. It:\n" - " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" - " - update height/width based `image_latents`, patchify `image_latents`." 
- - -#### QwenImage/img2img presets -IMAGE2IMAGE_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()), - ("input", QwenImageImg2ImgInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ("denoise", QwenImageDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -## 1.4 QwenImage/controlnet - -#### QwenImage/controlnet presets -CONTROLNET_BLOCKS = InsertableDict( - [ - ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image - ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet - ( - "controlnet_before_denoise", - QwenImageControlNetBeforeDenoiserStep(), - ), # before denoise step (after set_timesteps step) - ( - "controlnet_denoise_loop_before", - QwenImageLoopBeforeDenoiserControlNet(), - ), # controlnet loop step (insert before the denoiseloop_denoiser) - ] -) - - -## 1.5 QwenImage/auto encoders - - -#### for inpaint and img2img tasks -class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] - block_names = ["inpaint", "img2img"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" - + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n" - + " - if `mask_image` or `image` is not provided, step will be skipped." - ) - - -# for controlnet tasks -class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): - block_classes = [QwenImageControlNetVaeEncoderStep] - block_names = ["controlnet"] - block_trigger_inputs = ["control_image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n" - + " - if `control_image` is not provided, step will be skipped." - ) - - -## 1.6 QwenImage/auto inputs - - -# text2image/inpaint/img2img -class QwenImageAutoInputStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep] - block_names = ["inpaint", "img2img", "text2image"] - block_trigger_inputs = ["processed_mask_image", "image_latents", None] - - @property - def description(self): - return ( - "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. 
\n" - " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n" - + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n" - + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n" - ) - - -# controlnet -class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks): - block_classes = [QwenImageControlNetInputsStep] - block_names = ["controlnet"] - block_trigger_inputs = ["control_image_latents"] - - @property - def description(self): - return ( - "Controlnet input step that prepare the control_image_latents input.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n" - + " - if `control_image_latents` is not provided, step will be skipped." - ) - - -## 1.7 QwenImage/auto before denoise step -# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step - -# QwenImage/text2image before denoise -QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ] -) - - -class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values() - block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task." - - -# QwenImage/inpaint before denoise -QwenImageInpaintBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ] -) - - -class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageInpaintBeforeDenoiseBlocks.values() - block_names = QwenImageInpaintBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." - - -# QwenImage/img2img before denoise -QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()), - ("prepare_rope_inputs", QwenImageRoPEInputsStep()), - ] -) - - -class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values() - block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." 
- - -# auto before_denoise step for text2image, inpaint, img2img tasks -class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [ - QwenImageInpaintBeforeDenoiseStep, - QwenImageImg2ImgBeforeDenoiseStep, - QwenImageText2ImageBeforeDenoiseStep, - ] - block_names = ["inpaint", "img2img", "text2image"] - block_trigger_inputs = ["processed_mask_image", "image_latents", None] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" - + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n" - + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n" - + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n" - ) - - -# auto before_denoise step for controlnet tasks -class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [QwenImageControlNetBeforeDenoiserStep] - block_names = ["controlnet"] - block_trigger_inputs = ["control_image_latents"] - - @property - def description(self): - return ( - "Controlnet before denoise step that prepare the controlnet input.\n" - + "This is an auto pipeline block.\n" - + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n" - + " - if `control_image_latents` is not provided, step will be skipped." - ) - - -## 1.8 QwenImage/auto denoise - - -# auto denoise step for controlnet tasks: works for all tasks with controlnet -class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["mask", None] - - @property - def description(self): - return ( - "Controlnet step during the denoising process. \n" - " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n" - + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n" - + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n" - ) - - -# auto denoise step for everything: works for all tasks with or without controlnet -class QwenImageAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ - QwenImageControlNetAutoDenoiseStep, - QwenImageInpaintDenoiseStep, - QwenImageDenoiseStep, - ] - block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"] - block_trigger_inputs = ["control_image_latents", "mask", None] - - @property - def description(self): - return ( - "Denoise step that iteratively denoise the latents. \n" - " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. 
It also works with controlnet\n" - + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n" - + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n" - + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n" - ) - - -## 1.9 QwenImage/auto decode -# auto decode step for inpaint and text2image tasks - - -class QwenImageAutoDecodeStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep] - block_names = ["inpaint_decode", "decode"] - block_trigger_inputs = ["mask", None] - - @property - def description(self): - return ( - "Decode step that decode the latents into images. \n" - " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n" - + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n" - + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n" - ) - - -class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = [ - QwenImageAutoInputStep, - QwenImageOptionalControlNetInputStep, - QwenImageAutoBeforeDenoiseStep, - QwenImageOptionalControlNetBeforeDenoiseStep, - QwenImageAutoDenoiseStep, - ] - block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"] - - @property - def description(self): - return ( - "Core step that performs the denoising process. \n" - + " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n" - + " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n" - + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" - + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n" - + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n" - + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n" - + " - for image-to-image generation, you need to provide `image_latents`\n" - + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" - + " - to run the controlnet workflow, you need to provide `control_image_latents`\n" - + " - for text-to-image generation, all you need to provide is prompt embeddings" - ) - - -## 1.10 QwenImage/auto block & presets -AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageTextEncoderStep()), - ("vae_encoder", QwenImageAutoVaeEncoderStep()), - ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), - ("denoise", QwenImageCoreDenoiseStep()), - ("decode", QwenImageAutoDecodeStep()), - ] -) - - -class QwenImageAutoBlocks(SequentialPipelineBlocks): - model_name = "qwenimage" - - block_classes = AUTO_BLOCKS.values() - block_names = AUTO_BLOCKS.keys() - - @property - def description(self): - return ( - "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" - + "- for image-to-image generation, you need to provide `image`\n" - + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" - + "- to run the controlnet workflow, you need to provide `control_image`\n" - + "- for text-to-image generation, 
all you need to provide is `prompt`" - ) - - -# 2. QwenImage-Edit - -## 2.1 QwenImage-Edit/edit - -#### QwenImage-Edit/edit vl encoder: take both image and text prompts -QwenImageEditVLEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditResizeDynamicStep()), - ("encode", QwenImageEditTextEncoderStep()), - ] -) - - -class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditVLEncoderBlocks.values() - block_names = QwenImageEditVLEncoderBlocks.keys() - - @property - def description(self) -> str: - return "QwenImage-Edit VL encoder step that encode the image an text prompts together." - - -#### QwenImage-Edit/edit vae encoder -QwenImageEditVaeEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step - ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents - ] -) - - -class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditVaeEncoderBlocks.values() - block_names = QwenImageEditVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return "Vae encoder step that encode the image inputs into their latent representations." - - -#### QwenImage-Edit/edit input -QwenImageEditInputBlocks = InsertableDict( - [ - ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings - ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])), - ] -) - - -class QwenImageEditInputStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditInputBlocks.values() - block_names = QwenImageEditInputBlocks.keys() - - @property - def description(self): - return "Input step that prepares the inputs for the edit denoising step. It:\n" - " - make sure the text embeddings have consistent batch size as well as the additional inputs: \n" - " - `image_latents`.\n" - " - update height/width based `image_latents`, patchify `image_latents`." 
- - -#### QwenImage/edit presets -EDIT_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditVLEncoderStep()), - ("vae_encoder", QwenImageEditVaeEncoderStep()), - ("input", QwenImageEditInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ("denoise", QwenImageEditDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -## 2.2 QwenImage-Edit/edit inpaint - -#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step -QwenImageEditInpaintVaeEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image - ( - "preprocess", - QwenImageInpaintProcessImagesInputStep, - ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs - ( - "encode", - QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"), - ), # processed_image -> image_latents - ] -) - - -class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditInpaintVaeEncoderBlocks.values() - block_names = QwenImageEditInpaintVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return ( - "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n" - " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n" - " - process the resized image and mask image.\n" - " - create image latents." - ) - - -#### QwenImage-Edit/edit inpaint presets -EDIT_INPAINT_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditVLEncoderStep()), - ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()), - ("input", QwenImageInpaintInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ("denoise", QwenImageEditInpaintDenoiseStep()), - ("decode", QwenImageInpaintDecodeStep()), - ] -) - - -## 2.3 QwenImage-Edit/auto encoders - - -class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [ - QwenImageEditInpaintVaeEncoderStep, - QwenImageEditVaeEncoderStep, - ] - block_names = ["edit_inpaint", "edit"] - block_trigger_inputs = ["mask_image", "image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations. \n" - " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n" - + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n" - + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n" - + " - if `mask_image` or `image` is not provided, step will be skipped." 
- ) - - -## 2.4 QwenImage-Edit/auto inputs -class QwenImageEditAutoInputStep(AutoPipelineBlocks): - block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep] - block_names = ["edit_inpaint", "edit"] - block_trigger_inputs = ["processed_mask_image", "image_latents"] - - @property - def description(self): - return ( - "Input step that prepares the inputs for the edit denoising step.\n" - + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n" - + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n" - + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped." - ) - - -## 2.5 QwenImage-Edit/auto before denoise -# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step - -#### QwenImage-Edit/edit before denoise -QwenImageEditBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ] -) - - -class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditBeforeDenoiseBlocks.values() - block_names = QwenImageEditBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." - - -#### QwenImage-Edit/edit inpaint before denoise -QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict( - [ - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()), - ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ] -) - - -class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values() - block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys() - - @property - def description(self): - return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task." - - -# auto before_denoise step for edit and edit_inpaint tasks -class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks): - model_name = "qwenimage-edit" - block_classes = [ - QwenImageEditInpaintBeforeDenoiseStep, - QwenImageEditBeforeDenoiseStep, - ] - block_names = ["edit_inpaint", "edit"] - block_trigger_inputs = ["processed_mask_image", "image_latents"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" - + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n" - + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" - + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped." 
- ) - - -## 2.6 QwenImage-Edit/auto denoise - - -class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks): - model_name = "qwenimage-edit" - - block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep] - block_names = ["inpaint_denoise", "denoise"] - block_trigger_inputs = ["processed_mask_image", "image_latents"] - - @property - def description(self): - return ( - "Denoise step that iteratively denoise the latents. \n" - + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n" - + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n" - + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n" - + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped." - ) - - -## 2.7 QwenImage-Edit/auto blocks & presets - - -class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit" - block_classes = [ - QwenImageEditAutoInputStep, - QwenImageEditAutoBeforeDenoiseStep, - QwenImageEditAutoDenoiseStep, - ] - block_names = ["input", "before_denoise", "denoise"] - - @property - def description(self): - return ( - "Core step that performs the denoising process. \n" - + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n" - + " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" - + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" - + "This step support edit (img2img) and edit inpainting workflow for QwenImage Edit:\n" - + " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n" - + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n" - ) - - -EDIT_AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditVLEncoderStep()), - ("vae_encoder", QwenImageEditAutoVaeEncoderStep()), - ("denoise", QwenImageEditCoreDenoiseStep()), - ("decode", QwenImageAutoDecodeStep()), - ] -) - - -class QwenImageEditAutoBlocks(SequentialPipelineBlocks): - model_name = "qwenimage-edit" - block_classes = EDIT_AUTO_BLOCKS.values() - block_names = EDIT_AUTO_BLOCKS.keys() - - @property - def description(self): - return ( - "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n" - + "- for edit (img2img) generation, you need to provide `image`\n" - + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" - ) - - -#################### QwenImage Edit Plus ##################### - -# 3. QwenImage-Edit Plus - -## 3.1 QwenImage-Edit Plus / edit - -#### QwenImage-Edit Plus vl encoder: take both image and text prompts -QwenImageEditPlusVLEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditPlusResizeDynamicStep()), - ("encode", QwenImageEditPlusTextEncoderStep()), - ] -) - - -class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditPlusVLEncoderBlocks.values() - block_names = QwenImageEditPlusVLEncoderBlocks.keys() - - @property - def description(self) -> str: - return "QwenImage-Edit Plus VL encoder step that encode the image an text prompts together." 
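# Illustrative sketch of the aspect-preserving resize that the edit resize steps above
# describe (fit the image to a target area of 1024 * 1024 pixels while keeping its
# aspect ratio). Rounding and multiple-of constraints of the real image processors are
# omitted; `resize_to_area` is a hypothetical helper, not a library function.
import math


def resize_to_area(width: int, height: int, target_area: int = 1024 * 1024):
    scale = math.sqrt(target_area / (width * height))
    return round(width * scale), round(height * scale)


# e.g. a 2048x1024 input keeps its 2:1 aspect ratio but lands near one megapixel:
# resize_to_area(2048, 1024) -> (1448, 724)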
- - -#### QwenImage-Edit Plus vae encoder -QwenImageEditPlusVaeEncoderBlocks = InsertableDict( - [ - ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step - ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents - ] -) - - -class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" - block_classes = QwenImageEditPlusVaeEncoderBlocks.values() - block_names = QwenImageEditPlusVaeEncoderBlocks.keys() - - @property - def description(self) -> str: - return "Vae encoder step that encode the image inputs into their latent representations." - - -#### QwenImage Edit Plus presets -EDIT_PLUS_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditPlusVLEncoderStep()), - ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), - ("input", QwenImageEditInputStep()), - ("prepare_latents", QwenImagePrepareLatentsStep()), - ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), - ("denoise", QwenImageEditDenoiseStep()), - ("decode", QwenImageDecodeStep()), - ] -) - - -# auto before_denoise step for edit tasks -class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = [QwenImageEditBeforeDenoiseStep] - block_names = ["edit"] - block_trigger_inputs = ["image_latents"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" - + "This is an auto pipeline block that works for edit (img2img) task.\n" - + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" - + " - if `image_latents` is not provided, step will be skipped." - ) - - -## 3.2 QwenImage-Edit Plus/auto encoders - - -class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [ - QwenImageEditPlusVaeEncoderStep, - ] - block_names = ["edit"] - block_trigger_inputs = ["image"] - - @property - def description(self): - return ( - "Vae encoder step that encode the image inputs into their latent representations. \n" - " This is an auto pipeline block that works for edit task.\n" - + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n" - + " - if `image` is not provided, step will be skipped." - ) - - -## 3.3 QwenImage-Edit/auto blocks & presets - - -class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = [ - QwenImageEditAutoInputStep, - QwenImageEditPlusAutoBeforeDenoiseStep, - QwenImageEditAutoDenoiseStep, - ] - block_names = ["input", "before_denoise", "denoise"] - - @property - def description(self): - return ( - "Core step that performs the denoising process. 
\n" - + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n" - + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" - + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" - + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n" - + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n" - ) - - -EDIT_PLUS_AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", QwenImageEditPlusVLEncoderStep()), - ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()), - ("denoise", QwenImageEditPlusCoreDenoiseStep()), - ("decode", QwenImageAutoDecodeStep()), - ] -) - - -class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): - model_name = "qwenimage-edit-plus" - block_classes = EDIT_PLUS_AUTO_BLOCKS.values() - block_names = EDIT_PLUS_AUTO_BLOCKS.keys() - - @property - def description(self): - return ( - "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n" - + "- for edit (img2img) generation, you need to provide `image`\n" - ) - - -# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus - - -ALL_BLOCKS = { - "text2image": TEXT2IMAGE_BLOCKS, - "img2img": IMAGE2IMAGE_BLOCKS, - "edit": EDIT_BLOCKS, - "edit_inpaint": EDIT_INPAINT_BLOCKS, - "edit_plus": EDIT_PLUS_BLOCKS, - "inpaint": INPAINT_BLOCKS, - "controlnet": CONTROLNET_BLOCKS, - "auto": AUTO_BLOCKS, - "edit_auto": EDIT_AUTO_BLOCKS, - "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS, -} diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py new file mode 100644 index 000000000000..5837799d3431 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -0,0 +1,1214 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + QwenImageControlNetBeforeDenoiserStep, + QwenImageCreateMaskLatentsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageRoPEInputsStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageControlNetDenoiseStep, + QwenImageDenoiseStep, + QwenImageInpaintControlNetDenoiseStep, + QwenImageInpaintDenoiseStep, +) +from .encoders import ( + QwenImageControlNetVaeEncoderStep, + QwenImageInpaintProcessImagesInputStep, + QwenImageProcessImagesInputStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageAdditionalInputsStep, + QwenImageControlNetInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): + """ + Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. + + Components: + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`): + The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextEncoderStep()] + block_names = ["text_encoder"] + block_trigger_inputs = ["prompt"] + + @property + def description(self) -> str: + return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block." + " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided." + " - if `prompt` is not provided, step will be skipped." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for inpainting tasks. It: + - Resizes the image to the target size, based on `height` and `width`. + - Processes and updates `image` and `mask_image`. + - Creates `image_latents`. + + Components: + image_mask_processor (`InpaintProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. 
+ padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for inpainting tasks. It:\n" + " - Resizes the image to the target size, based on `height` and `width`.\n" + " - Processes and updates `image` and `mask_image`.\n" + " - Creates `image_latents`." + ) + + +# auto_docstring +class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage" + + block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()] + block_names = ["preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that preprocess andencode the image inputs into their latent representations." + + +class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] + block_names = ["inpaint", "img2img"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n" + + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n" + + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# optional controlnet vae encoder +# auto_docstring +class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block. + - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided. + - if `control_image` is not provided, step will be skipped. + + Components: + vae (`AutoencoderKLQwenImage`) controlnet (`QwenImageControlNetModel`) control_image_processor + (`VaeImageProcessor`) + + Inputs: + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. 
+ generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + + block_classes = [QwenImageControlNetVaeEncoderStep] + block_names = ["controlnet"] + block_trigger_inputs = ["control_image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + + "This is an auto pipeline block.\n" + + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n" + + " - if `control_image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep()] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the img2img denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# auto_docstring +class QwenImageInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the inpainting denoising step. 
It: + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep( + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] + ), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return "Input step that prepares the inputs for the inpainting denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n" + " - update height/width based `image_latents`, patchify `image_latents`." + + +# assemble prepare latents steps +# auto_docstring +class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the pachified latents `mask` based on the processedmask image. + + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be + generated from vae encoder and updated in input step.) 
+ timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + latents (`Tensor`): + The scaled noisy latents to use for inpainting/image-to-image denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + + model_name = "qwenimage" + block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] + block_names = ["add_noise_to_latents", "create_mask_latents"] + + @property + def description(self) -> str: + return ( + "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n" + " - Add noise to the image latents to create the latents input for the denoiser.\n" + " - Create the pachified latents `mask` based on the processedmask image.\n" + ) + + +# assemble denoising steps + + +# Qwen Image (text2image) +# auto_docstring +class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. 
+ """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (inpainting) +# auto_docstring +class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageInpaintInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageInpaintPrepareLatentsStep(), + QwenImageRoPEInputsStep(), + QwenImageInpaintDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." 
+ + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (image2image) +# auto_docstring +class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (text2image) with controlnet +# auto_docstring +class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+ prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (inpainting) with controlnet +# auto_docstring +class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageInpaintInputStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageInpaintPrepareLatentsStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageInpaintControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image (image2image) with controlnet +# auto_docstring +class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) controlnet + (`QwenImageControlNetModel`) guider (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. 
+ negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + control_image_latents (`Tensor`): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage" + block_classes = [ + QwenImageImg2ImgInputStep(), + QwenImageControlNetInputsStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImagePrepareLatentsWithStrengthStep(), + QwenImageRoPEInputsStep(), + QwenImageControlNetBeforeDenoiserStep(), + QwenImageControlNetDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "controlnet_input", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "prepare_rope_inputs", + "controlnet_before_denoise", + "controlnet_denoise", + "after_denoise", + ] + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." 
+ + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto denoise step for QwenImage +class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): + block_classes = [ + QwenImageCoreDenoiseStep, + QwenImageInpaintCoreDenoiseStep, + QwenImageImg2ImgCoreDenoiseStep, + QwenImageControlNetCoreDenoiseStep, + QwenImageControlNetInpaintCoreDenoiseStep, + QwenImageControlNetImg2ImgCoreDenoiseStep, + ] + block_names = [ + "text2image", + "inpaint", + "img2img", + "controlnet_text2image", + "controlnet_inpaint", + "controlnet_img2img", + ] + block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"] + default_block_name = "text2image" + + def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None): + if control_image_latents is not None: + if processed_mask_image is not None: + return "controlnet_inpaint" + elif image_latents is not None: + return "controlnet_img2img" + else: + return "controlnet_text2image" + else: + if processed_mask_image is not None: + return "inpaint" + elif image_latents is not None: + return "img2img" + else: + return "text2image" + + @property + def description(self): + return ( + "Core step that performs the denoising process. \n" + + " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n" + + " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n" + + " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n" + + " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n" + + " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n" + + " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n" + + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n" + + " - for image-to-image generation, you need to provide `image_latents`\n" + + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n" + + " - to run the controlnet workflow, you need to provide `control_image_latents`\n" + + " - for text-to-image generation, all you need to provide is prompt embeddings" + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# standard decode step works for most tasks except for inpaint +# auto_docstring +class QwenImageDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." 
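# Illustrative sketch of the "mask overlay" that the inpaint postprocess step below
# applies: keep the original pixels outside the mask and the newly decoded pixels
# inside it. This is only a conceptual NumPy picture; the actual behaviour (including
# `padding_mask_crop` handling) is implemented by `InpaintProcessor`, not by this helper.
import numpy as np


def overlay_sketch(original: np.ndarray, decoded: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend decoded pixels into the original where mask == 1 (arrays in [0, 1], shape (H, W, C))."""
    if mask.ndim == 2:
        mask = mask[..., None]  # broadcast a (H, W) mask over the channel axis
    return original * (1.0 - mask) + decoded * mask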
+
+
+# Inpaint decode step
+# auto_docstring
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+    """
+    Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the
+    mask overlay to the original image.
+
+    Components:
+        vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`)
+
+    Inputs:
+        latents (`Tensor`):
+            The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise
+            step.
+        output_type (`str`, *optional*, defaults to pil):
+            Output format: 'pil', 'np', 'pt'.
+        mask_overlay_kwargs (`Dict`, *optional*):
+            The kwargs for the postprocess step to apply the mask overlay. Generated in
+            InpaintProcessImagesInputStep.
+
+    Outputs:
+        images (`List`):
+            Generated images. (tensor output of the vae decoder.)
+    """
+
+    model_name = "qwenimage"
+    block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
+    block_names = ["decode", "postprocess"]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the latents to images and postprocesses the generated image, optionally applying the mask overlay to the original image."
+
+
+# Auto decode step for QwenImage
+class QwenImageAutoDecodeStep(AutoPipelineBlocks):
+    block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
+    block_names = ["inpaint_decode", "decode"]
+    block_trigger_inputs = ["mask", None]
+
+    @property
+    def description(self):
+        return (
+            "Decode step that decodes the latents into images. \n"
+            " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+            + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+            + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
+        )
+
+
+# ====================
+# 5. AUTO BLOCKS & PRESETS
+# ====================
+AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageAutoTextEncoderStep()),
+        ("vae_encoder", QwenImageAutoVaeEncoderStep()),
+        ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
+        ("denoise", QwenImageAutoCoreDenoiseStep()),
+        ("decode", QwenImageAutoDecodeStep()),
+    ]
+)
+
+
+# auto_docstring
+class QwenImageAutoBlocks(SequentialPipelineBlocks):
+    """
+    Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
+    - for image-to-image generation, you need to provide `image`
+    - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
+    - to run the controlnet workflow, you need to provide `control_image`
+    - for text-to-image generation, all you need to provide is `prompt`
+
+    Components:
+        text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
+        The tokenizer to use guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
+        (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) controlnet (`QwenImageControlNetModel`)
+        control_image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
+        (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
+
+    Inputs:
+        prompt (`str`, *optional*):
+            The prompt or prompts to guide image generation.
+        negative_prompt (`str`, *optional*):
+            The prompt or prompts not to guide the image generation.
+        max_sequence_length (`int`, *optional*, defaults to 1024):
+            Maximum sequence length for prompt encoding.
+ mask_image (`Image`, *optional*): + Mask image for inpainting. + image (`Union[Image, List]`, *optional*): + Reference image(s) for denoising. Can be a single image or list of images. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + image_latents (`Tensor`, *optional*): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_image_latents (`Tensor`, *optional*): + The control image latents to use for the denoising process. Can be generated in controlnet vae encoder + step. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. 
+ """ + + model_name = "qwenimage" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" + + "- for image-to-image generation, you need to provide `image`\n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n" + + "- to run the controlnet workflow, you need to provide `control_image`\n" + + "- for text-to-image generation, all you need to provide is `prompt`" + ) + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py new file mode 100644 index 000000000000..e1e5c4335481 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -0,0 +1,790 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam +from .before_denoise import ( + QwenImageCreateMaskLatentsStep, + QwenImageEditRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImagePrepareLatentsWithStrengthStep, + QwenImageSetTimestepsStep, + QwenImageSetTimestepsWithStrengthStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageInpaintProcessImagesOutputStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageEditDenoiseStep, + QwenImageEditInpaintDenoiseStep, +) +from .encoders import ( + QwenImageEditInpaintProcessImagesInputStep, + QwenImageEditProcessImagesInputStep, + QwenImageEditResizeStep, + QwenImageEditTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Edit VL encoder step that encode the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. 
+ + Outputs: + resized_image (`List`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditTextEncoderStep(), + ] + block_names = ["resize", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Edit VL encoder step that encode the image and text prompts together." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# Edit VAE encoder +# auto_docstring +class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +# Edit Inpaint VAE encoder +# auto_docstring +class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It: + - resize the image for target area (1024 * 1024) while maintaining the aspect ratio. + - process the resized image and mask image. + - create image latents. + + Components: + image_resize_processor (`VaeImageProcessor`) image_mask_processor (`InpaintProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + mask_image (`Image`): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + processed_mask_image (`Tensor`): + The processed mask image + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditResizeStep(), + QwenImageEditInpaintProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n" + " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n" + " - process the resized image and mask image.\n" + " - create image latents." 
+ ) + + +# Auto VAE encoder +class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["mask_image", "image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations.\n" + "This is an auto pipeline block.\n" + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n" + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n" + " - if `mask_image` or `image` is not provided, step will be skipped." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageEditInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit denoising step. 
It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# auto_docstring +class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit inpaint denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified and + batch-expanded) + processed_mask_image (`Tensor`): + The processed mask image (batch-expanded) + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageAdditionalInputsStep( + additional_batch_inputs=[ + InputParam(name="processed_mask_image", type_hint=torch.Tensor, description="The processed mask image") + ] + ), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit inpaint denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# assemble prepare latents steps +# auto_docstring +class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. 
+    - Create the patchified latents `mask` based on the processed mask image.
+
+    Components:
+        scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`QwenImagePachifier`)
+
+    Inputs:
+        latents (`Tensor`):
+            The initial random noise, can be generated in the prepare latents step.
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step. (Can be
+            generated from vae encoder and updated in input step.)
+        timesteps (`Tensor`):
+            The timesteps to use for the denoising process. Can be generated in set_timesteps step.
+        processed_mask_image (`Tensor`):
+            The processed mask to use for the inpainting process.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        dtype (`dtype`, *optional*, defaults to torch.float32):
+            The dtype of the model inputs, can be generated in input step.
+
+    Outputs:
+        initial_noise (`Tensor`):
+            The initial random noise used for inpainting denoising.
+        latents (`Tensor`):
+            The scaled noisy latents to use for inpainting/image-to-image denoising.
+        mask (`Tensor`):
+            The mask to use for the inpainting process.
+    """
+
+    model_name = "qwenimage-edit"
+    block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
+    block_names = ["add_noise_to_latents", "create_mask_latents"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
+            " - Add noise to the image latents to create the latents input for the denoiser.\n"
+            " - Create the patchified latents `mask` based on the processed mask image.\n"
+        )
+
+
+# Qwen Image Edit (image2image) core denoise step
+# auto_docstring
+class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
+    """
+    Core denoising workflow for QwenImage-Edit edit (img2img) task.
+
+    Components:
+        pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
+        (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`)
+
+    Inputs:
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        prompt_embeds (`Tensor`):
+            text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        prompt_embeds_mask (`Tensor`):
+            mask for the text embeddings. Can be generated from text_encoder step.
+        negative_prompt_embeds (`Tensor`, *optional*):
+            negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
+        negative_prompt_embeds_mask (`Tensor`, *optional*):
+            mask for the negative text embeddings. Can be generated from text_encoder step.
+        height (`int`, *optional*):
+            The height in pixels of the generated image.
+        width (`int`, *optional*):
+            The width in pixels of the generated image.
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step.
+        latents (`Tensor`, *optional*):
+            Pre-generated noisy latents for image generation.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps.
+        sigmas (`List`, *optional*):
+            Custom sigmas for the denoising process.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            conditional model inputs for the denoiser: e.g.
prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageEditRoPEInputsStep(), + QwenImageEditDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit edit (img2img) task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Qwen Image Edit (inpainting) core denoise step +# auto_docstring +class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit inpaint task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + processed_mask_image (`Tensor`, *optional*): + The processed mask image + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsWithStrengthStep(), + QwenImageEditInpaintPrepareLatentsStep(), + QwenImageEditRoPEInputsStep(), + QwenImageEditInpaintDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_inpaint_latents", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit edit inpaint task." 
+ + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# Auto core denoise step for QwenImage Edit +class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): + model_name = "qwenimage-edit" + block_classes = [ + QwenImageEditInpaintCoreDenoiseStep, + QwenImageEditCoreDenoiseStep, + ] + block_names = ["edit_inpaint", "edit"] + block_trigger_inputs = ["processed_mask_image", "image_latents"] + default_block_name = "edit" + + def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]: + if processed_mask_image is not None: + return "edit_inpaint" + elif image_latents is not None: + return "edit" + return None + + @property + def description(self): + return ( + "Auto core denoising step that selects the appropriate workflow based on inputs.\n" + " - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n" + " - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n" + "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit." + ) + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# Decode step (standard) +# auto_docstring +class QwenImageEditDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image." + + +# Inpaint decode step +# auto_docstring +class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask + overlay to the original image. + + Components: + vae (`AutoencoderKLQwenImage`) image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + mask_overlay_kwargs (`Dict`, *optional*): + The kwargs for the postprocess step to apply the mask overlay. generated in + InpaintProcessImagesInputStep. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit" + block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image." 
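# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of this diff): the individual edit-inpaint
# steps defined above can be assembled into a runnable pipeline. This assumes
# the generic modular-pipelines workflow (`SequentialPipelineBlocks.from_blocks_dict`,
# `init_pipeline`, `load_components`) applies to these blocks; the repo id,
# image URLs, and prompt below are placeholders, not values taken from this PR.
#
#     import torch
#     from diffusers.utils import load_image
#
#     blocks = SequentialPipelineBlocks.from_blocks_dict(
#         InsertableDict(
#             [
#                 ("text_encoder", QwenImageEditVLEncoderStep()),
#                 ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
#                 ("denoise", QwenImageEditInpaintCoreDenoiseStep()),
#                 ("decode", QwenImageEditInpaintDecodeStep()),
#             ]
#         )
#     )
#     pipe = blocks.init_pipeline("Qwen/Qwen-Image-Edit")  # placeholder repo id
#     pipe.load_components(torch_dtype=torch.bfloat16)
#
#     image = pipe(
#         image=load_image("https://example.com/input.png"),      # placeholder
#         mask_image=load_image("https://example.com/mask.png"),  # placeholder
#         prompt="replace the masked area with a small wooden boat",
#         output="images",
#     )[0]
# ---------------------------------------------------------------------------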
+
+
+# Auto decode step
+class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
+    block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
+    block_names = ["inpaint_decode", "decode"]
+    block_trigger_inputs = ["mask", None]
+
+    @property
+    def description(self):
+        return (
+            "Decode step that decodes the latents into images.\n"
+            "This is an auto pipeline block.\n"
+            " - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+            " - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
+        )
+
+    @property
+    def outputs(self):
+        return [
+            OutputParam.template("images"),
+        ]
+
+
+# ====================
+# 5. AUTO BLOCKS & PRESETS
+# ====================
+
+EDIT_AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", QwenImageEditVLEncoderStep()),
+        ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
+        ("denoise", QwenImageEditAutoCoreDenoiseStep()),
+        ("decode", QwenImageEditAutoDecodeStep()),
+    ]
+)
+
+
+# auto_docstring
+class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
+    """
+    Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.
+    - for edit (img2img) generation, you need to provide `image`
+    - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide
+      `padding_mask_crop`
+
+    Components:
+        image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
+        (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
+        (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) pachifier (`QwenImagePachifier`) scheduler
+        (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
+
+    Inputs:
+        image (`Union[Image, List]`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+        prompt (`str`):
+            The prompt or prompts to guide image generation.
+        negative_prompt (`str`, *optional*):
+            The prompt or prompts not to guide the image generation.
+        mask_image (`Image`, *optional*):
+            Mask image for inpainting.
+        padding_mask_crop (`int`, *optional*):
+            Padding for mask cropping in inpainting.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+        num_images_per_prompt (`int`, *optional*, defaults to 1):
+            The number of images to generate per prompt.
+        height (`int`):
+            The height in pixels of the generated image.
+        width (`int`):
+            The width in pixels of the generated image.
+        image_latents (`Tensor`):
+            image latents used to guide the image generation. Can be generated from vae_encoder step.
+        processed_mask_image (`Tensor`, *optional*):
+            The processed mask image
+        latents (`Tensor`):
+            Pre-generated noisy latents for image generation.
+        num_inference_steps (`int`):
+            The number of denoising steps.
+        sigmas (`List`, *optional*):
+            Custom sigmas for the denoising process.
+        strength (`float`, *optional*, defaults to 0.9):
+            Strength for img2img/inpainting.
+        attention_kwargs (`Dict`, *optional*):
+            Additional kwargs for attention processors.
+        **denoiser_input_fields (`None`, *optional*):
+            conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
+        output_type (`str`, *optional*, defaults to pil):
+            Output format: 'pil', 'np', 'pt'.
+        mask_overlay_kwargs (`Dict`, *optional*):
+            The kwargs for the postprocess step to apply the mask overlay. generated in
+            InpaintProcessImagesInputStep.
+
+    Outputs:
+        images (`List`):
+            Generated images.
+ """ + + model_name = "qwenimage-edit" + block_classes = EDIT_AUTO_BLOCKS.values() + block_names = EDIT_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n" + "- for edit (img2img) generation, you need to provide `image`\n" + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n" + ) + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py new file mode 100644 index 000000000000..37656cef5d76 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -0,0 +1,409 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + QwenImageEditPlusRoPEInputsStep, + QwenImagePrepareLatentsStep, + QwenImageSetTimestepsStep, +) +from .decoders import ( + QwenImageAfterDenoiseStep, + QwenImageDecoderStep, + QwenImageProcessImagesOutputStep, +) +from .denoise import ( + QwenImageEditDenoiseStep, +) +from .encoders import ( + QwenImageEditPlusProcessImagesInputStep, + QwenImageEditPlusResizeStep, + QwenImageEditPlusTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageEditPlusAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. 
+ """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusResizeStep(), + QwenImageEditPlusTextEncoderStep(), + ] + block_names = ["resize", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# auto_docstring +class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): + """ + VAE encoder step that encodes image inputs into latent representations. + Each image is resized independently based on its own aspect ratio to 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + Images resized to 1024x1024 target area for VAE encoding + resized_cond_image (`List`): + Images resized to 384x384 target area for VL text encoding + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusResizeStep(), + QwenImageEditPlusProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + ] + block_names = ["resize", "preprocess", "encode"] + + @property + def description(self) -> str: + return ( + "VAE encoder step that encodes image inputs into latent representations.\n" + "Each image is resized independently based on its own aspect ratio to 1024x1024 target area." + ) + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the Edit Plus denoising step. It: + - Standardizes text embeddings batch size. + - Processes list of image latents: patchifies, concatenates along dim=1, expands batch. + - Outputs lists of image_height/image_width for RoPE calculation. + - Defaults height/width from last image in the list. + + Components: + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. 
(batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified, + concatenated, and batch-expanded) + """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageEditPlusAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the Edit Plus denoising step. It:\n" + " - Standardizes text embeddings batch size.\n" + " - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n" + " - Outputs lists of image_height/image_width for RoPE calculation.\n" + " - Defaults height/width from last image in the list." + ) + + +# Qwen Image Edit Plus (image2image) core denoise step +# auto_docstring +class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. + + Components: + pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. 
+ """ + + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditPlusInputStep(), + QwenImagePrepareLatentsStep(), + QwenImageSetTimestepsStep(), + QwenImageEditPlusRoPEInputsStep(), + QwenImageEditDenoiseStep(), + QwenImageAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. DECODE +# ==================== + + +# auto_docstring +class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocesses the generated image. + + Components: + vae (`AutoencoderKLQwenImage`) image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step and unpacked in the after denoise + step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "qwenimage-edit-plus" + block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] + block_names = ["decode", "postprocess"] + + @property + def description(self): + return "Decode step that decodes the latents to images and postprocesses the generated image." + + +# ==================== +# 5. AUTO BLOCKS & PRESETS +# ==================== + +EDIT_PLUS_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), + ("denoise", QwenImageEditPlusCoreDenoiseStep()), + ("decode", QwenImageEditPlusDecodeStep()), + ] +) + + +# auto_docstring +class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus. + - `image` is required input (can be single image or list of images). + - Each image is resized independently based on its own aspect ratio. + - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) pachifier (`QwenImagePachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) + transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. 
+ attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. + """ + + model_name = "qwenimage-edit-plus" + block_classes = EDIT_PLUS_AUTO_BLOCKS.values() + block_names = EDIT_PLUS_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n" + "- `image` is required input (can be single image or list of images).\n" + "- Each image is resized independently based on its own aspect ratio.\n" + "- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area." + ) + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py new file mode 100644 index 000000000000..fdfeab048835 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py @@ -0,0 +1,368 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + QwenImageLayeredPrepareLatentsStep, + QwenImageLayeredRoPEInputsStep, + QwenImageLayeredSetTimestepsStep, +) +from .decoders import ( + QwenImageLayeredAfterDenoiseStep, + QwenImageLayeredDecoderStep, +) +from .denoise import ( + QwenImageLayeredDenoiseStep, +) +from .encoders import ( + QwenImageEditProcessImagesInputStep, + QwenImageLayeredGetImagePromptStep, + QwenImageLayeredPermuteLatentsStep, + QwenImageLayeredResizeStep, + QwenImageTextEncoderStep, + QwenImageVaeEncoderStep, +) +from .inputs import ( + QwenImageLayeredAdditionalInputsStep, + QwenImageTextInputsStep, +) + + +logger = logging.get_logger(__name__) + + +# ==================== +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): + """ + QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not + provided. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. 
+ resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + resized_image (`List`): + The resized images + prompt (`str`): + The prompt or prompts to guide image generation. If not provided, updated using image caption + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredResizeStep(), + QwenImageLayeredGetImagePromptStep(), + QwenImageTextEncoderStep(), + ] + block_names = ["resize", "get_image_prompt", "encode"] + + @property + def description(self) -> str: + return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided." + + +# ==================== +# 2. VAE ENCODER +# ==================== + + +# Edit VAE encoder +# auto_docstring +class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + image_resize_processor (`VaeImageProcessor`) image_processor (`VaeImageProcessor`) vae + (`AutoencoderKLQwenImage`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`Tensor`): + The processed image + image_latents (`Tensor`): + The latent representation of the input image. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredResizeStep(), + QwenImageEditProcessImagesInputStep(), + QwenImageVaeEncoderStep(), + QwenImageLayeredPermuteLatentsStep(), + ] + block_names = ["resize", "preprocess", "encode", "permute"] + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +# ==================== +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# ==================== + + +# assemble input steps +# auto_docstring +class QwenImageLayeredInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the layered denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. 
Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + + Outputs: + batch_size (`int`): + The batch size of the prompt embeddings + dtype (`dtype`): + The data type of the prompt embeddings + prompt_embeds (`Tensor`): + The prompt embeddings. (batch-expanded) + prompt_embeds_mask (`Tensor`): + The encoder attention mask. (batch-expanded) + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. (batch-expanded) + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. (batch-expanded) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + if not provided, updated to image height + width (`int`): + if not provided, updated to image width + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. (patchified + with layered pachifier and batch-expanded) + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageTextInputsStep(), + QwenImageLayeredAdditionalInputsStep(), + ] + block_names = ["text_inputs", "additional_inputs"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the layered denoising step. It:\n" + " - make sure the text embeddings have consistent batch size as well as the additional inputs.\n" + " - update height/width based `image_latents`, patchify `image_latents`." + ) + + +# Qwen Image Layered (image2image) core denoise step +# auto_docstring +class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Layered img2img task. + + Components: + pachifier (`QwenImageLayeredPachifier`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + image_latents (`Tensor`): + image latents used to guide the image generation. Can be generated from vae_encoder step. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. 
+ attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "qwenimage-layered" + block_classes = [ + QwenImageLayeredInputStep(), + QwenImageLayeredPrepareLatentsStep(), + QwenImageLayeredSetTimestepsStep(), + QwenImageLayeredRoPEInputsStep(), + QwenImageLayeredDenoiseStep(), + QwenImageLayeredAfterDenoiseStep(), + ] + block_names = [ + "input", + "prepare_latents", + "set_timesteps", + "prepare_rope_inputs", + "denoise", + "after_denoise", + ] + + @property + def description(self): + return "Core denoising workflow for QwenImage-Layered img2img task." + + @property + def outputs(self): + return [ + OutputParam.template("latents"), + ] + + +# ==================== +# 4. AUTO BLOCKS & PRESETS +# ==================== + +LAYERED_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageLayeredTextEncoderStep()), + ("vae_encoder", QwenImageLayeredVaeEncoderStep()), + ("denoise", QwenImageLayeredCoreDenoiseStep()), + ("decode", QwenImageLayeredDecoderStep()), + ] +) + + +# auto_docstring +class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for layered denoising tasks using QwenImage-Layered. + + Components: + image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor + (`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`) + image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`) pachifier (`QwenImageLayeredPachifier`) + scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`) + + Inputs: + image (`Union[Image, List]`): + Reference image(s) for denoising. Can be a single image or list of images. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`None`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`List`): + Generated images. 
+ """ + + model_name = "qwenimage-layered" + block_classes = LAYERED_AUTO_BLOCKS.values() + block_names = LAYERED_AUTO_BLOCKS.keys() + + @property + def description(self): + return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered." + + @property + def outputs(self): + return [ + OutputParam.template("images"), + ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py index 59e1a13a5db2..892435989d00 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -90,6 +90,88 @@ def unpack_latents(self, latents, height, width, vae_scale_factor=8): return latents +class QwenImageLayeredPachifier(ConfigMixin): + """ + A class to pack and unpack latents for QwenImage Layered. + + Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W). + """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = 2): + super().__init__() + + def pack_latents(self, latents): + """ + Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4). + """ + + if latents.ndim != 5: + raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}") + + batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape + patch_size = self.config.patch_size + + if latent_height % patch_size != 0 or latent_width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}" + ) + + latents = latents.view( + batch_size, + layers, + num_channels_latents, + latent_height // patch_size, + patch_size, + latent_width // patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape( + batch_size, + layers * (latent_height // patch_size) * (latent_width // patch_size), + num_channels_latents * patch_size * patch_size, + ) + return latents + + def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8): + """ + Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W). + """ + + if latents.ndim != 3: + raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}") + + batch_size, _, channels = latents.shape + patch_size = self.config.patch_size + + height = patch_size * (int(height) // (vae_scale_factor * patch_size)) + width = patch_size * (int(width) // (vae_scale_factor * patch_size)) + + latents = latents.view( + batch_size, + layers + 1, + height // patch_size, + width // patch_size, + channels // (patch_size * patch_size), + patch_size, + patch_size, + ) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + latents = latents.reshape( + batch_size, + layers + 1, + channels // (patch_size * patch_size), + height, + width, + ) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin): """ A ModularPipeline for QwenImage. @@ -203,3 +285,13 @@ class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline): """ default_blocks_name = "QwenImageEditPlusAutoBlocks" + + +class QwenImageLayeredModularPipeline(QwenImageModularPipeline): + """ + A ModularPipeline for QwenImage-Layered. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "QwenImageLayeredAutoBlocks" diff --git a/src/diffusers/modular_pipelines/qwenimage/node_utils.py b/src/diffusers/modular_pipelines/qwenimage/node_utils.py deleted file mode 100644 index 3230ece68abc..000000000000 --- a/src/diffusers/modular_pipelines/qwenimage/node_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# mellon nodes -QwenImage_NODE_TYPES_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - "vae", - ], - "outputs": [ - "controlnet_out", - ], - "block_names": ["controlnet_vae_encoder"], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - "controlnet", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - }, -} diff --git a/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py b/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py new file mode 100644 index 000000000000..8e7beb555760 --- /dev/null +++ b/src/diffusers/modular_pipelines/qwenimage/prompt_templates.py @@ -0,0 +1,121 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt templates for QwenImage pipelines. 
+ +This module centralizes all prompt templates used across different QwenImage pipeline variants: +- QwenImage (base): Text-only encoding for text-to-image generation +- QwenImage Edit: VL encoding with single image for image editing +- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing +- QwenImage Layered: Auto-captioning for image decomposition +""" + +# ============================================ +# QwenImage Base (text-only encoding) +# ============================================ +# Used for text-to-image generation where only text prompt is encoded + +QWENIMAGE_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the image by detailing the color, shape, size, texture, quantity, text, " + "spatial relationships of the objects and background:<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34 + + +# ============================================ +# QwenImage Edit (VL encoding with single image) +# ============================================ +# Used for single-image editing where both image and text are encoded together + +QWENIMAGE_EDIT_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n" + "<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Edit Plus (VL encoding with multiple images) +# ============================================ +# Used for multi-reference editing where multiple images and text are encoded together +# The img_template is used to format each image in the prompt + +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = ( + "<|im_start|>system\n" + "Describe the key features of the input image (color, shape, size, texture, objects, background), " + "then explain how the user's text instruction should alter or modify the image. " + "Generate a new image that meets the user's requirements while maintaining consistency " + "with the original input where appropriate.<|im_end|>\n" + "<|im_start|>user\n{}<|im_end|>\n" + "<|im_start|>assistant\n" +) +QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" +QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64 + + +# ============================================ +# QwenImage Layered (auto-captioning) +# ============================================ +# Used for image decomposition where the VL model generates a caption from the input image +# if no prompt is provided. These prompts instruct the model to describe the image in detail. + +QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# Image Annotator\n" + "You are a professional image annotator. Please write an image caption based on the input image:\n" + "1. Write the caption using natural, descriptive language without structured formats or rich text.\n" + "2. 
Enrich caption details by including:\n" + " - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n" + " - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, " + "attachment relations, action relations, comparative relations, causal relations, and so on\n" + " - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n" + " - Identify the text clearly visible in the image, without translation or explanation, " + "and highlight it in the caption with quotation marks\n" + "3. Maintain authenticity and accuracy:\n" + " - Avoid generalizations\n" + " - Describe all visible information in the image, while do not add information not explicitly shown in the image\n" + "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n" + "<|im_start|>assistant\n" +) + +QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = ( + "<|im_start|>system\n" + "You are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + "# 图像标注器\n" + "你是一个专业的图像标注器。请基于输入图像,撰写图注:\n" + "1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n" + "2. 通过加入以下内容,丰富图注细节:\n" + " - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n" + " - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n" + " - 环境细节:例如天气、光照、颜色、纹理、气氛等\n" + " - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n" + "3. 保持真实性与准确性:\n" + " - 不要使用笼统的描述\n" + " - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n" + "<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n" + "<|im_start|>assistant\n" +) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py deleted file mode 100644 index 3e788bf94741..000000000000 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -SDXL_NODE_TYPES_PARAMS_MAP = { - "controlnet": { - "inputs": [ - "control_image", - "controlnet_conditioning_scale", - "control_guidance_start", - "control_guidance_end", - "height", - "width", - ], - "model_inputs": [ - "controlnet", - ], - "outputs": [ - "controlnet_out", - ], - "block_names": [None], - }, - "denoise": { - "inputs": [ - "embeddings", - "width", - "height", - "seed", - "num_inference_steps", - "guidance_scale", - "image_latents", - "strength", - # custom adapters coming in as inputs - "controlnet", - # ip_adapter is optional and custom; include if available - "ip_adapter", - ], - "model_inputs": [ - "unet", - "guider", - "scheduler", - ], - "outputs": [ - "latents", - "latents_preview", - ], - "block_names": ["denoise"], - }, - "vae_encoder": { - "inputs": [ - "image", - "width", - "height", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "image_latents", - ], - "block_names": ["vae_encoder"], - }, - "text_encoder": { - "inputs": [ - "prompt", - "negative_prompt", - ], - "model_inputs": [ - "text_encoders", - ], - "outputs": [ - "embeddings", - ], - "block_names": ["text_encoder"], - }, - "decoder": { - "inputs": [ - "latents", - ], - "model_inputs": [ - "vae", - ], - "outputs": [ - "images", - ], - "block_names": ["decode"], - }, -} diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index dc49df8eab8c..4fd69c6ca6ab 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -530,6 +530,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe device = components._execution_device dtype = torch.float32 + vae_dtype = components.vae.dtype height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -555,7 +556,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe vae=components.vae, generator=block_state.generator, device=device, - dtype=dtype, + dtype=vae_dtype, latent_channels=components.num_channels_latents, ) @@ -627,6 +628,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe device = components._execution_device dtype = torch.float32 + vae_dtype = components.vae.dtype height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -659,7 +661,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe vae=components.vae, generator=block_state.generator, device=device, - dtype=dtype, + dtype=vae_dtype, latent_channels=components.num_channels_latents, ) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index b3b70b2f9be1..905111bcf42d 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -84,7 +84,7 @@ def description(self): class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] - block_names = ["image_resize", "vae_image_encoder"] + block_names = ["image_resize", "vae_encoder"] @property def description(self): @@ -142,7 +142,7 @@ def description(self): class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] - block_names = ["image_resize", "last_image_resize", 
"vae_image_encoder"] + block_names = ["image_resize", "last_image_resize", "vae_encoder"] @property def description(self): @@ -203,7 +203,7 @@ def description(self): ## vae encoder class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] - block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"] + block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"] block_trigger_inputs = ["last_image", "image"] @property @@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks): block_names = [ "text_encoder", "image_encoder", - "vae_image_encoder", + "vae_encoder", "denoise", "decode", ] @@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks): ] block_names = [ "text_encoder", - "vae_image_encoder", + "vae_encoder", "denoise", "decode", ] @@ -384,7 +384,7 @@ def description(self): [ ("image_resize", WanImageResizeStep), ("image_encoder", WanImage2VideoImageEncoderStep), - ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("vae_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])), ("set_timesteps", WanSetTimestepsStep), @@ -401,7 +401,7 @@ def description(self): ("image_resize", WanImageResizeStep), ("last_image_resize", WanImageCropResizeStep), ("image_encoder", WanFLF2VImageEncoderStep), - ("vae_image_encoder", WanFLF2VVaeImageEncoderStep), + ("vae_encoder", WanFLF2VVaeImageEncoderStep), ("input", WanTextInputStep), ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])), ("set_timesteps", WanSetTimestepsStep), @@ -416,7 +416,7 @@ def description(self): [ ("text_encoder", WanTextEncoderStep), ("image_encoder", WanAutoImageEncoderStep), - ("vae_image_encoder", WanAutoVaeImageEncoderStep), + ("vae_encoder", WanAutoVaeImageEncoderStep), ("denoise", WanAutoDenoiseStep), ("decode", WanImageVaeDecoderStep), ] @@ -438,7 +438,7 @@ def description(self): IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict( [ ("image_resize", WanImageResizeStep), - ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("vae_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), ("set_timesteps", WanSetTimestepsStep), ("prepare_latents", WanPrepareLatentsStep), @@ -450,7 +450,7 @@ def description(self): AUTO_BLOCKS_WAN22 = InsertableDict( [ ("text_encoder", WanTextEncoderStep), - ("vae_image_encoder", WanAutoVaeImageEncoderStep), + ("vae_encoder", WanAutoVaeImageEncoderStep), ("denoise", Wan22AutoDenoiseStep), ("decode", WanImageVaeDecoderStep), ] diff --git a/src/diffusers/modular_pipelines/z_image/__init__.py b/src/diffusers/modular_pipelines/z_image/__init__.py new file mode 100644 index 000000000000..c8a8c14396c0 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/__init__.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["decoders"] = 
["ZImageVaeDecoderStep"] + _import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "ZImageAutoBlocks", + ] + _import_structure["modular_pipeline"] = ["ZImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .decoders import ZImageVaeDecoderStep + from .encoders import ZImageTextEncoderStep + from .modular_blocks import ( + ALL_BLOCKS, + ZImageAutoBlocks, + ) + from .modular_pipeline import ZImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/z_image/before_denoise.py b/src/diffusers/modular_pipelines/z_image/before_denoise.py new file mode 100644 index 000000000000..35ea768f12c3 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/before_denoise.py @@ -0,0 +1,621 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch + +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. 
The function will:
+    - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
+    - If batch size equals batch_size: repeat each element num_images_per_prompt times
+
+    Args:
+        input_name (str): Name of the input tensor (used for error messages)
+        input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
+        batch_size (int): The base batch size (number of prompts)
+        num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
+
+    Returns:
+        torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
+
+    Raises:
+        ValueError: If input_tensor is not a torch.Tensor or has an invalid batch size
+
+    Examples:
+        tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
+        batch_size=2, num_images_per_prompt=2) repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
+        [4, 3]
+
+        tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
+        tensor, batch_size=2, num_images_per_prompt=2) repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
+        - shape: [4, 3]
+    """
+    # make sure input is a tensor
+    if not isinstance(input_tensor, torch.Tensor):
+        raise ValueError(f"`{input_name}` must be a tensor")
+
+    # make sure the input tensor (e.g. image_latents) has batch size 1 or the same batch size as the prompts
+    if input_tensor.shape[0] == 1:
+        repeat_by = batch_size * num_images_per_prompt
+    elif input_tensor.shape[0] == batch_size:
+        repeat_by = num_images_per_prompt
+    else:
+        raise ValueError(
+            f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
+        )
+
+    # expand the tensor to match the batch_size * num_images_per_prompt
+    input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
+
+    return input_tensor
+
+
+def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> Tuple[int, int]:
+    """Calculate image dimensions from latent tensor dimensions.
+
+    This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width
+    by the VAE scale factor.
+
+    Args:
+        latents (torch.Tensor): The latent tensor. Must have 4 dimensions.
+            Expected shapes: [batch, channels, height, width]
+        vae_scale_factor_spatial (int): The scale factor used by the VAE to compress image spatial dimension.
+ By default, it is 16 + Returns: + Tuple[int, int]: The calculated image dimensions as (height, width) + """ + latent_height, latent_width = latents.shape[2:] + height = latent_height * vae_scale_factor_spatial // 2 + width = latent_width * vae_scale_factor_spatial // 2 + + return height, width + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageTextInputStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_images_per_prompt." + ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + type_hint=List[torch.Tensor], + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=List[torch.Tensor], + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `transformer.dtype`)", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if not isinstance(block_state.prompt_embeds, list): + raise ValueError( + f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}." + ) + if not isinstance(block_state.negative_prompt_embeds, list): + raise ValueError( + f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}." + ) + if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds): + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but" + f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`" + f" {len(block_state.negative_prompt_embeds)}." 
+ ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = len(block_state.prompt_embeds) + block_state.dtype = block_state.prompt_embeds[0].dtype + + if block_state.num_images_per_prompt > 1: + prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)] + block_state.prompt_embeds = prompt_embeds + + if block_state.negative_prompt_embeds is not None: + negative_prompt_embeds = [ + npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt) + ] + block_state.negative_prompt_embeds = negative_prompt_embeds + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageAdditionalInputsStep(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (List[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_latents"]. + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process image_latents (default behavior) ZImageAdditionalInputsStep() + + # Configure to process multiple image latent inputs + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. 
For additional batch inputs: Expands batch dimensions to match final batch size"
+        )
+
+        # Inputs info
+        inputs_info = ""
+        if self._image_latent_inputs or self._additional_batch_inputs:
+            inputs_info = "\n\nConfigured inputs:"
+            if self._image_latent_inputs:
+                inputs_info += f"\n  - Image latent inputs: {self._image_latent_inputs}"
+            if self._additional_batch_inputs:
+                inputs_info += f"\n  - Additional batch inputs: {self._additional_batch_inputs}"
+
+        # Placement guidance
+        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+        return summary_section + inputs_info + placement_section
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        inputs = [
+            InputParam(name="num_images_per_prompt", default=1),
+            InputParam(name="batch_size", required=True),
+            InputParam(name="height"),
+            InputParam(name="width"),
+        ]
+
+        # Add image latent inputs
+        for image_latent_input_name in self._image_latent_inputs:
+            inputs.append(InputParam(name=image_latent_input_name))
+
+        # Add additional batch inputs
+        for input_name in self._additional_batch_inputs:
+            inputs.append(InputParam(name=input_name))
+
+        return inputs
+
+    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # Process image latent inputs (used to derive default height/width)
+        for image_latent_input_name in self._image_latent_inputs:
+            image_latent_tensor = getattr(block_state, image_latent_input_name)
+            if image_latent_tensor is None:
+                continue
+
+            # Calculate height/width from the latents
+            height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial)
+            block_state.height = block_state.height or height
+            block_state.width = block_state.width or width
+
+        # Process additional batch inputs (only batch expansion)
+        for input_name in self._additional_batch_inputs:
+            input_tensor = getattr(block_state, input_name)
+            if input_tensor is None:
+                continue
+
+            # Only expand batch size
+            input_tensor = repeat_tensor_to_batch_size(
+                input_name=input_name,
+                input_tensor=input_tensor,
+                num_images_per_prompt=block_state.num_images_per_prompt,
+                batch_size=block_state.batch_size,
+            )
+
+            setattr(block_state, input_name, input_tensor)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class ZImagePrepareLatentsStep(ModularPipelineBlocks):
+    model_name = "z-image"
+
+    @property
+    def description(self) -> str:
+        return "Prepare latents step that prepares the latents for the text-to-image generation process"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("latents", type_hint=Optional[torch.Tensor]),
+            InputParam("num_images_per_prompt", type_hint=int, default=1),
+            InputParam("generator"),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. 
Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + def check_inputs(self, components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + + @staticmethod + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (comp.vae_scale_factor * 2)) + width = 2 * (int(width) // (comp.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + + block_state.latents = self.prepare_latents( + components, + batch_size=block_state.batch_size * block_state.num_images_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step." 
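+
+    # Rough intuition for the shift applied in `__call__` below: the packed sequence length derived from the
+    # latent shape is mapped through `calculate_shift` (copied from the Flux pipeline) to a value `mu`, which
+    # `retrieve_timesteps` forwards to `scheduler.set_timesteps(..., mu=mu)`. Larger images therefore get a
+    # sigma schedule shifted toward higher noise levels, with the exact shift controlled by the scheduler's
+    # `base_shift`/`max_shift` config values.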
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("num_inference_steps", default=9), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process" + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3] + image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify + + mu = calculate_shift( + image_seq_len, + base_seq_len=components.scheduler.config.get("base_image_seq_len", 256), + max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096), + base_shift=components.scheduler.config.get("base_shift", 0.5), + max_shift=components.scheduler.config.get("max_shift", 1.15), + ) + components.scheduler.sigma_min = 0.0 + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=block_state.sigmas, + mu=mu, + ) + + self.set_block_state(state, block_state) + return components, state + + +class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("timesteps", required=True), + InputParam("num_inference_steps", required=True), + InputParam("strength", default=0.6), + ] + + def check_inputs(self, components, block_state): + if block_state.strength < 0.0 or block_state.strength > 1.0: + raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}") + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps) + + t_start = int(max(block_state.num_inference_steps - init_timestep, 0)) + timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :] + if hasattr(components.scheduler, "set_begin_index"): + components.scheduler.set_begin_index(t_start * components.scheduler.order) + + block_state.timesteps = timesteps + block_state.num_inference_steps = block_state.num_inference_steps - t_start + + self.set_block_state(state, block_state) + return components, state + + +class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step." 
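+
+    # Rough sketch of the noising performed in `__call__` below, assuming the flow-matching scheduler's
+    # `scale_noise` linearly interpolates between the sample and the noise:
+    #     latents = (1 - sigma_0) * image_latents + sigma_0 * latents   # `latents` starts out as pure noise
+    # where sigma_0 is the sigma at the first retained timestep, so a lower strength (fewer retained steps)
+    # keeps more of the reference image in the starting latents.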
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("latents", required=True), + InputParam("image_latents", required=True), + InputParam("timesteps", required=True), + ] + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/decoders.py b/src/diffusers/modular_pipelines/z_image/decoders.py new file mode 100644 index 000000000000..cdb6a2e5eac1 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/decoders.py @@ -0,0 +1,91 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageVaeDecoderStep(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8 * 2}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam( + "latents", + required=True, + ), + InputParam( + name="output_type", + default="pil", + type_hint=str, + description="The type of the output images, can be 'pil', 'np', 'pt'", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "images", + type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]], + description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + latents = block_state.latents.to(vae_dtype) + latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor + + block_state.images = components.vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + 
self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py new file mode 100644 index 000000000000..5f76a8459fde --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -0,0 +1,314 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ZImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import ZImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents = block_state.latents.unsqueeze(2).to( + block_state.dtype + ) # [batch_size, num_channels, 1, height, width] + block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width] + + timestep = t.expand(latents.shape[0]).to(block_state.dtype) + timestep = (1000 - timestep) / 1000 + block_state.timestep = timestep + return components, block_state + + +class ZImageLoopDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + def __init__( + self, + guider_input_fields: Dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. 
For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0, "enabled": False}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", ZImageTransformer2DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + inputs = [ + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="denoiser_input_fields", + description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ), + ] + guider_input_names = [] + uncond_guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.append(value[0]) + uncond_guider_input_names.append(value[1]) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True)) + for name in uncond_guider_input_names: + inputs.append(InputParam(name=name)) + return inputs + + @torch.no_grad() + def __call__( + self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). 
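+        # With this block's default guider_input_fields = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")},
+        # the CFG case would look roughly like:
+        #     guider_state = [
+        #         {"cap_feats": block_state.prompt_embeds, "__guidance_identifier__": "pred_cond"},
+        #         {"cap_feats": block_state.negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},
+        #     ]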
+ guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + + def _convert_dtype(v, dtype): + if isinstance(v, torch.Tensor): + return v.to(dtype) + elif isinstance(v, list): + return [_convert_dtype(t, dtype) for t in v] + return v + + cond_kwargs = { + k: _convert_dtype(v, block_state.dtype) + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + model_out_list = components.transformer( + x=block_state.latent_model_input, + t=block_state.timestep, + return_dict=False, + **cond_kwargs, + )[0] + noise_pred = torch.stack(model_out_list, dim=0).squeeze(2) + guider_state_batch.noise_pred = -noise_pred + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class ZImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "z-image" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ZImageDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "z-image" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def loop_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + + return components, state + + +class ZImageDenoiseStep(ZImageDenoiseLoopWrapper): + block_classes = [ + ZImageLoopBeforeDenoiser, + ZImageLoopDenoiser( + guider_input_fields={ + "cap_feats": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + ZImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `ZImageLoopBeforeDenoiser`\n" + " - `ZImageLoopDenoiser`\n" + " - `ZImageLoopAfterDenoiser`\n" + "This block supports text-to-image and image-to-image tasks for Z-Image." + ) diff --git a/src/diffusers/modular_pipelines/z_image/encoders.py b/src/diffusers/modular_pipelines/z_image/encoders.py new file mode 100644 index 000000000000..f5769fe2deec --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/encoders.py @@ -0,0 +1,344 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Union + +import PIL +import torch +from transformers import Qwen2Tokenizer, Qwen3Model + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ZImageModularPipeline + + +if is_ftfy_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_qwen_prompt_embeds( + text_encoder: Qwen3Model, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: torch.device, + max_sequence_length: int = 512, +) -> List[torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + prompt_embeds_list = [] + + for i in range(len(prompt_embeds)): + prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]]) + + return prompt_embeds_list + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + image_tensor: torch.Tensor, + vae: AutoencoderKL, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(image_tensor, torch.Tensor): + raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.") + + if isinstance(generator, list) and len(generator) != image_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}." 
+        )
+
+    image_tensor = image_tensor.to(device=device, dtype=dtype)
+
+    if isinstance(generator, list):
+        image_latents = [
+            retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i])
+            for i in range(image_tensor.shape[0])
+        ]
+        image_latents = torch.cat(image_latents, dim=0)
+    else:
+        image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator)
+
+    image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+    return image_latents
+
+
+class ZImageTextEncoderStep(ModularPipelineBlocks):
+    model_name = "z-image"
+
+    @property
+    def description(self) -> str:
+        return "Text Encoder step that generates text_embeddings to guide the image generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("text_encoder", Qwen3Model),
+            ComponentSpec("tokenizer", Qwen2Tokenizer),
+            ComponentSpec(
+                "guider",
+                ClassifierFreeGuidance,
+                config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("prompt"),
+            InputParam("negative_prompt"),
+            InputParam("max_sequence_length", default=512),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "prompt_embeds",
+                type_hint=List[torch.Tensor],
+                kwargs_type="denoiser_input_fields",
+                description="text embeddings used to guide the image generation",
+            ),
+            OutputParam(
+                "negative_prompt_embeds",
+                type_hint=List[torch.Tensor],
+                kwargs_type="denoiser_input_fields",
+                description="negative text embeddings used to guide the image generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        if block_state.prompt is not None and (
+            not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
+        ):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
+
+    @staticmethod
+    def encode_prompt(
+        components,
+        prompt: str,
+        device: Optional[torch.device] = None,
+        prepare_unconditional_embeds: bool = True,
+        negative_prompt: Optional[str] = None,
+        max_sequence_length: int = 512,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            prepare_unconditional_embeds (`bool`):
+                whether to prepare unconditional embeddings or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            max_sequence_length (`int`, defaults to `512`):
+                The maximum number of text tokens to be used for the generation process.
+        """
+        device = device or components._execution_device
+        if not isinstance(prompt, list):
+            prompt = [prompt]
+        batch_size = len(prompt)
+
+        prompt_embeds = get_qwen_prompt_embeds(
+            text_encoder=components.text_encoder,
+            tokenizer=components.tokenizer,
+            prompt=prompt,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+
+        negative_prompt_embeds = None
+        if prepare_unconditional_embeds:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = get_qwen_prompt_embeds(
+                text_encoder=components.text_encoder,
+                tokenizer=components.tokenizer,
+                prompt=negative_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    @torch.no_grad()
+    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
+        # Get inputs and intermediates
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.device = components._execution_device
+
+        # Encode input prompt
+        (
+            block_state.prompt_embeds,
+            block_state.negative_prompt_embeds,
+        ) = self.encode_prompt(
+            components=components,
+            prompt=block_state.prompt,
+            device=block_state.device,
+            prepare_unconditional_embeds=components.requires_unconditional_embeds,
+            negative_prompt=block_state.negative_prompt,
+            max_sequence_length=block_state.max_sequence_length,
+        )
+
+        # Add outputs
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class ZImageVaeImageEncoderStep(ModularPipelineBlocks):
+    model_name = "z-image"
+
+    @property
+    def description(self) -> str:
+        return "Vae Image Encoder step that generates image_latents from the input image to guide the image generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 8 * 2}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", type_hint=PIL.Image.Image, required=True),
+            InputParam("height"),
+            InputParam("width"),
+            InputParam("generator"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "image_latents",
+                type_hint=torch.Tensor,
+                description="latent representation of the input image, used to condition the image generation",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(components, block_state):
+        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
+            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
+        ):
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
+ ) + + def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.image + + device = components._execution_device + dtype = torch.float32 + vae_dtype = components.vae.dtype + + image_tensor = components.image_processor.preprocess( + image, height=block_state.height, width=block_state.width + ).to(device=device, dtype=dtype) + + block_state.image_latents = encode_vae_image( + image_tensor=image_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=vae_dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/z_image/modular_blocks.py b/src/diffusers/modular_pipelines/z_image/modular_blocks.py new file mode 100644 index 000000000000..a54baeccaf0c --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/modular_blocks.py @@ -0,0 +1,191 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + ZImageAdditionalInputsStep, + ZImagePrepareLatentsStep, + ZImagePrepareLatentswithImageStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImageTextInputStep, +) +from .decoders import ZImageVaeDecoderStep +from .denoise import ( + ZImageDenoiseStep, +) +from .encoders import ( + ZImageTextEncoderStep, + ZImageVaeImageEncoderStep, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# z-image +# text2image +class ZImageCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + ZImageTextInputStep, + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageDenoiseStep, + ] + block_names = ["input", "prepare_latents", "set_timesteps", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n" + + " - `ZImageSetTimestepsStep` is used to set the timesteps\n" + + " - `ZImageDenoiseStep` is used to denoise the latents\n" + ) + + +# z-image: image2image +## denoise +class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + ZImageTextInputStep, + ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]), + ZImagePrepareLatentsStep, + ZImageSetTimestepsStep, + ZImageSetTimestepsWithStrengthStep, + ZImagePrepareLatentswithImageStep, + ZImageDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "prepare_latents", + "set_timesteps", + 
"set_timesteps_with_strength", + "prepare_latents_with_image", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n" + + " - `ZImageSetTimestepsStep` is used to set the timesteps\n" + + " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n" + + " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with image\n" + + " - `ZImageDenoiseStep` is used to denoise the latents\n" + ) + + +## auto blocks +class ZImageAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + ZImageImage2ImageCoreDenoiseStep, + ZImageCoreDenoiseStep, + ] + block_names = ["image2image", "text2image"] + block_trigger_inputs = ["image_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2image and image2image tasks." + " - `ZImageCoreDenoiseStep` (text2image) for text2image tasks." + " - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks." + + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n" + + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n" + ) + + +class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks): + block_classes = [ZImageVaeImageEncoderStep] + block_names = ["vae_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self) -> str: + return "Vae Image Encoder step that encode the image to generate the image latents" + +"This is an auto pipeline block that works for image2image tasks." + +" - `ZImageVaeImageEncoderStep` is used when `image` is provided." + +" - if `image` is not provided, step will be skipped." + + +class ZImageAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + ZImageTextEncoderStep, + ZImageAutoVaeImageEncoderStep, + ZImageAutoDenoiseStep, + ZImageVaeDecoderStep, + ] + block_names = ["text_encoder", "vae_encoder", "denoise", "decode"] + + @property + def description(self) -> str: + return "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n" + +" - for text-to-image generation, all you need to provide is `prompt`\n" + +" - for image-to-image generation, you need to provide `image`\n" + +" - if `image` is not provided, step will be skipped." 
+ + +# presets +TEXT2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", ZImageTextEncoderStep), + ("input", ZImageTextInputStep), + ("prepare_latents", ZImagePrepareLatentsStep), + ("set_timesteps", ZImageSetTimestepsStep), + ("denoise", ZImageDenoiseStep), + ("decode", ZImageVaeDecoderStep), + ] +) + +IMAGE2IMAGE_BLOCKS = InsertableDict( + [ + ("text_encoder", ZImageTextEncoderStep), + ("vae_encoder", ZImageVaeImageEncoderStep), + ("input", ZImageTextInputStep), + ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])), + ("prepare_latents", ZImagePrepareLatentsStep), + ("set_timesteps", ZImageSetTimestepsStep), + ("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep), + ("prepare_latents_with_image", ZImagePrepareLatentswithImageStep), + ("denoise", ZImageDenoiseStep), + ("decode", ZImageVaeDecoderStep), + ] +) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", ZImageTextEncoderStep), + ("vae_encoder", ZImageAutoVaeImageEncoderStep), + ("denoise", ZImageAutoDenoiseStep), + ("decode", ZImageVaeDecoderStep), + ] +) + +ALL_BLOCKS = { + "text2image": TEXT2IMAGE_BLOCKS, + "image2image": IMAGE2IMAGE_BLOCKS, + "auto": AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/z_image/modular_pipeline.py b/src/diffusers/modular_pipelines/z_image/modular_pipeline.py new file mode 100644 index 000000000000..f1d8e53a3639 --- /dev/null +++ b/src/diffusers/modular_pipelines/z_image/modular_pipeline.py @@ -0,0 +1,72 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import ZImageLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZImageModularPipeline( + ModularPipeline, + ZImageLoraLoaderMixin, +): + """ + A ModularPipeline for Z-Image. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "ZImageAutoBlocks" + + @property + def default_height(self): + return 1024 + + @property + def default_width(self): + return 1024 + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor_spatial = 16 + if hasattr(self, "image_processor") and self.image_processor is not None: + vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor + return vae_scale_factor_spatial + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3d669aecf556..65378631a172 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -15,6 +15,7 @@ is_torch_available, is_torch_npu_available, is_transformers_available, + is_transformers_version, ) @@ -128,8 +129,8 @@ "AnimateDiffVideoToVideoControlNetPipeline", ] _import_structure["bria"] = ["BriaPipeline"] - _import_structure["bria_fibo"] = ["BriaFiboPipeline"] - _import_structure["flux2"] = ["Flux2Pipeline"] + _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] + _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -154,7 +155,7 @@ "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] - _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"] + _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline", "ChromaInpaintPipeline"] _import_structure["cogvideo"] = [ "CogVideoXPipeline", "CogVideoXImageToVideoPipeline", @@ -165,6 +166,7 @@ _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["cosmos"] = [ + "Cosmos2_5_PredictBasePipeline", "Cosmos2TextToImagePipeline", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", @@ -287,10 +289,13 @@ "LTXImageToVideoPipeline", "LTXConditionPipeline", "LTXLatentUpsamplePipeline", + "LTXI2VLongMultiPromptPipeline", ] + _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] + _import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -404,7 +409,13 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] - _import_structure["z_image"] = ["ZImagePipeline"] + _import_structure["z_image"] = [ + "ZImageImg2ImgPipeline", + "ZImagePipeline", + "ZImageControlNetPipeline", + "ZImageControlNetInpaintPipeline", + "ZImageOmniPipeline", + ] 
_import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -421,8 +432,11 @@ "QwenImageEditInpaintPipeline", "QwenImageControlNetInpaintPipeline", "QwenImageControlNetPipeline", + "QwenImageLayeredPipeline", ] _import_structure["chronoedit"] = ["ChronoEditPipeline"] + _import_structure["glm_image"] = ["GlmImagePipeline"] + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -583,8 +597,8 @@ from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline from .bria import BriaPipeline - from .bria_fibo import BriaFiboPipeline - from .chroma import ChromaImg2ImgPipeline, ChromaPipeline + from .bria_fibo import BriaFiboEditPipeline, BriaFiboPipeline + from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline from .chronoedit import ChronoEditPipeline from .cogvideo import ( CogVideoXFunControlPipeline, @@ -615,6 +629,7 @@ StableDiffusionXLControlNetXSPipeline, ) from .cosmos import ( + Cosmos2_5_PredictBasePipeline, Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, CosmosTextToWorldPipeline, @@ -663,7 +678,8 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .flux2 import Flux2Pipeline + from .flux2 import Flux2KleinPipeline, Flux2Pipeline + from .glm_image import GlmImagePipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline from .hunyuan_video import ( @@ -718,7 +734,15 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) - from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline + from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline + from .ltx import ( + LTXConditionPipeline, + LTXI2VLongMultiPromptPipeline, + LTXImageToVideoPipeline, + LTXLatentUpsamplePipeline, + LTXPipeline, + ) + from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline from .lucy import LucyEditPipeline from .lumina import LuminaPipeline, LuminaText2ImgPipeline from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline @@ -762,6 +786,7 @@ QwenImageEditPlusPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, + QwenImageLayeredPipeline, QwenImagePipeline, ) from .sana import ( @@ -841,7 +866,13 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline + from .z_image import ( + ZImageControlNetInpaintPipeline, + ZImageControlNetPipeline, + ZImageImg2ImgPipeline, + ZImageOmniPipeline, + ZImagePipeline, + ) try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index 3be0129088fb..42083378d465 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -887,7 +887,13 @@ def __call__( prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) self.scheduler.set_timesteps(num_inference_steps, device=device) # 5. Prepare latents. 
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 6f3a609aba4a..51a9a31c4259 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -897,16 +897,20 @@ def __call__( dtype = self.dtype # 3. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device if not enforce_inference_steps: timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) else: denoising_inference_steps = int(num_inference_steps / strength) timesteps, denoising_inference_steps = retrieve_timesteps( - self.scheduler, denoising_inference_steps, device, timesteps, sigmas + self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas ) timesteps = timesteps[-num_inference_steps:] latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py index b00f344598ad..c3ac7df2cc8c 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -1100,16 +1100,20 @@ def __call__( dtype = self.dtype # 3. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device if not enforce_inference_steps: timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) else: denoising_inference_steps = int(num_inference_steps / strength) timesteps, denoising_inference_steps = retrieve_timesteps( - self.scheduler, denoising_inference_steps, device, timesteps, sigmas + self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas ) timesteps = timesteps[-num_inference_steps:] latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index bb9884e41381..1d75e4bef31e 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -586,7 +586,13 @@ def __call__( # 4. Prepare timesteps # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) # 5. Prepare latents. 
latent_channels = self.transformer.config.in_channels diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 044d854390e4..5ee44190e23b 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,6 +52,8 @@ FluxKontextPipeline, FluxPipeline, ) +from .flux2 import Flux2KleinPipeline, Flux2Pipeline +from .glm_image import GlmImagePipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -73,6 +75,7 @@ from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .lumina import LuminaPipeline from .lumina2 import Lumina2Pipeline +from .ovis_image import OvisImagePipeline from .pag import ( HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, @@ -98,6 +101,7 @@ QwenImageEditPlusPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, + QwenImageLayeredPipeline, QwenImagePipeline, ) from .sana import SanaPipeline @@ -119,6 +123,13 @@ ) from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline +from .z_image import ( + ZImageControlNetInpaintPipeline, + ZImageControlNetPipeline, + ZImageImg2ImgPipeline, + ZImageOmniPipeline, + ZImagePipeline, +) AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -154,14 +165,22 @@ ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), ("flux-kontext", FluxKontextPipeline), + ("flux2-klein", Flux2KleinPipeline), + ("flux2", Flux2Pipeline), ("lumina", LuminaPipeline), ("lumina2", Lumina2Pipeline), ("chroma", ChromaPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), + ("glm_image", GlmImagePipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), + ("z-image", ZImagePipeline), + ("z-image-controlnet", ZImageControlNetPipeline), + ("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline), + ("z-image-omni", ZImageOmniPipeline), + ("ovis", OvisImagePipeline), ] ) @@ -186,9 +205,13 @@ ("flux-controlnet", FluxControlNetImg2ImgPipeline), ("flux-control", FluxControlImg2ImgPipeline), ("flux-kontext", FluxKontextPipeline), + ("flux2-klein", Flux2KleinPipeline), + ("flux2", Flux2Pipeline), ("qwenimage", QwenImageImg2ImgPipeline), ("qwenimage-edit", QwenImageEditPipeline), ("qwenimage-edit-plus", QwenImageEditPlusPipeline), + ("qwenimage-layered", QwenImageLayeredPipeline), + ("z-image", ZImageImg2ImgPipeline), ] ) diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py index 206a463b394b..8dd77270902c 100644 --- a/src/diffusers/pipelines/bria_fibo/__init__.py +++ b/src/diffusers/pipelines/bria_fibo/__init__.py @@ -23,6 +23,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"] + _import_structure["pipeline_bria_fibo_edit"] = ["BriaFiboEditPipeline"] + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +35,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_bria_fibo import BriaFiboPipeline + from .pipeline_bria_fibo_edit import BriaFiboEditPipeline else: import sys diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py new file mode 100644 index 
000000000000..aae8fc7367da
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo_edit.py
@@ -0,0 +1,1133 @@
+# Copyright (c) Bria.ai. All rights reserved.
+#
+# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
+# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
+#
+# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
+# indicate if changes were made, and do not use the material for commercial purposes.
+#
+# See the license for further details.
+
+import json
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin
+from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
+from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
+from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+PipelineMaskInput = Union[
+    torch.FloatTensor, Image.Image, List[Image.Image], List[torch.FloatTensor], np.ndarray, List[np.ndarray]
+]
+
+# TODO: Update example docstring
+EXAMPLE_DOC_STRING = """
+    Example:
+        ```python
+        import json
+
+        import torch
+        from PIL import Image
+
+        from diffusers import BriaFiboEditPipeline
+        from diffusers.modular_pipelines import ModularPipelineBlocks
+
+        torch.set_grad_enabled(False)
+        vlm_pipe = ModularPipelineBlocks.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
+        vlm_pipe = vlm_pipe.init_pipeline()
+
+        pipe = BriaFiboEditPipeline.from_pretrained(
+            "briaai/fibo-edit",
+            torch_dtype=torch.bfloat16,
+        )
+        pipe.to("cuda")
+
+        output = vlm_pipe(
+            prompt="A hyper-detailed, ultra-fluffy owl sitting in the trees at night, looking directly at the camera with wide, adorable, expressive eyes. Its feathers are soft and voluminous, catching the cool moonlight with subtle silver highlights. The owl's gaze is curious and full of charm, giving it a whimsical, storybook-like personality."
+ ) + json_prompt_generate = json.loads(output.values["json_prompt"]) + + image = Image.open("image_generate.png") + + edit_prompt = "Make the owl to be a cat" + + json_prompt_generate["edit_instruction"] = edit_prompt + + results_generate = pipe( + prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=3.5, image=image, output_type="np" + ) + ``` +""" + +PREFERRED_RESOLUTION = { + 256 * 256: [(208, 304), (224, 288), (256, 256), (288, 224), (304, 208), (320, 192), (336, 192)], + 512 * 512: [ + (416, 624), + (432, 592), + (464, 560), + (512, 512), + (544, 480), + (576, 448), + (592, 432), + (608, 416), + (624, 416), + (640, 400), + (672, 384), + (704, 368), + ], + 1024 * 1024: [ + (832, 1248), + (880, 1184), + (912, 1136), + (1024, 1024), + (1136, 912), + (1184, 880), + (1216, 848), + (1248, 832), + (1248, 832), + (1264, 816), + (1296, 800), + (1360, 768), + ], +} + + +def is_valid_edit_json(json_input: str | dict): + """ + Check if the input is a valid JSON string or dict with an "edit_instruction" key. + + Args: + json_input (`str` or `dict`): + The JSON string or dict to check. + + Returns: + `bool`: True if the input is a valid JSON string or dict with an "edit_instruction" key, False otherwise. + """ + try: + if isinstance(json_input, str) and "edit_instruction" in json_input: + json.loads(json_input) + return True + elif isinstance(json_input, dict) and "edit_instruction" in json_input: + return True + else: + return False + except json.JSONDecodeError: + return False + + +def is_valid_mask(mask: PipelineMaskInput): + """ + Check if the mask is a valid mask. + """ + if isinstance(mask, torch.Tensor): + return True + elif isinstance(mask, Image.Image): + return True + elif isinstance(mask, list): + return all(isinstance(m, (torch.Tensor, Image.Image, np.ndarray)) for m in mask) + elif isinstance(mask, np.ndarray): + return mask.ndim in [2, 3] and mask.min() >= 0 and mask.max() <= 1 + else: + return False + + +def get_mask_size(mask: PipelineMaskInput): + """ + Get the size of the mask. + """ + if isinstance(mask, torch.Tensor): + return mask.shape[-2:] + elif isinstance(mask, Image.Image): + return mask.size[::-1] # (height, width) + elif isinstance(mask, list): + return [get_mask_size(m) for m in mask] + elif isinstance(mask, np.ndarray): + return mask.shape[-2:] + else: + return None + + +def get_image_size(image: PipelineImageInput): + """ + Get the size of the image. 
+ """ + if isinstance(image, torch.Tensor): + return image.shape[-2:] + elif isinstance(image, Image.Image): + return image.size[::-1] # (height, width) + elif isinstance(image, list): + return [get_image_size(i) for i in image] + else: + return None + + +def paste_mask_on_image(mask: PipelineMaskInput, image: PipelineImageInput): + """convert mask and image to PIL Images and paste the mask on the image""" + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, Image.Image): + pass + elif isinstance(mask, list): + mask = mask[0] + if isinstance(mask, torch.Tensor): + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = Image.fromarray((mask.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + elif isinstance(mask, np.ndarray): + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, Image.Image): + pass + elif isinstance(image, list): + image = image[0] + if isinstance(image, torch.Tensor): + if image.ndim == 3: + image = image.permute(1, 2, 0) + image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + elif isinstance(image, np.ndarray): + image = Image.fromarray((image * 255).astype(np.uint8)) + + mask = mask.convert("L") + image = image.convert("RGB") + gray_color = (128, 128, 128) + gray_img = Image.new("RGB", image.size, gray_color) + image = Image.composite(gray_img, image, mask) + return image + + +class BriaFiboEditPipeline(DiffusionPipeline, FluxLoraLoaderMixin): + r""" + Args: + transformer (`BriaFiboTransformer2DModel`): + The transformer model for 2D diffusion modeling. + scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`): + Scheduler to be used with `transformer` to denoise the encoded latents. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder for encoding and decoding images to and from latent representations. + text_encoder (`SmolLM3ForCausalLM`): + Text encoder for processing input prompts. + tokenizer (`AutoTokenizer`): + Tokenizer used for processing the input text prompts for the text_encoder. 
+ """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + transformer: BriaFiboTransformer2DModel, + scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers], + vae: AutoencoderKLWan, + text_encoder: SmolLM3ForCausalLM, + tokenizer: AutoTokenizer, + ): + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor = 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # * 2) + self.default_sample_size = 32 # 64 + + def get_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + if not prompt: + raise ValueError("`prompt` must be a non-empty string or list of strings.") + + batch_size = len(prompt) + bot_token_id = 128000 + + text_encoder_device = device if device is not None else torch.device("cpu") + if not isinstance(text_encoder_device, torch.device): + text_encoder_device = torch.device(text_encoder_device) + + if all(p == "" for p in prompt): + input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device) + attention_mask = torch.ones_like(input_ids) + else: + tokenized = self.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized.input_ids.to(text_encoder_device) + attention_mask = tokenized.attention_mask.to(text_encoder_device) + + if any(p == "" for p in prompt): + empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device) + input_ids[empty_rows] = bot_token_id + attention_mask[empty_rows] = 1 + + encoder_outputs = self.text_encoder( + input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_outputs.hidden_states + + prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1) + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + hidden_states = tuple( + layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states + ) + attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) + + return prompt_embeds, hidden_states, attention_mask + + @staticmethod + def pad_embedding(prompt_embeds, max_tokens, attention_mask=None): + # Pad embeddings to `max_tokens` while preserving the mask of real tokens. 
+ batch_size, seq_len, dim = prompt_embeds.shape + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device) + else: + attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if max_tokens < seq_len: + raise ValueError("`max_tokens` must be greater or equal to the current sequence length.") + + if max_tokens > seq_len: + pad_length = max_tokens - seq_len + padding = torch.zeros( + (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, padding], dim=1) + + mask_padding = torch.zeros( + (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 3000, + lora_scale: Optional[float] = None, + ): + r""" + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + guidance_scale (`float`): + Guidance scale for classifier free guidance. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_attention_mask = None + negative_prompt_attention_mask = None + if prompt_embeds is None: + prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype) + prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers] + + if guidance_scale > 1: + if isinstance(negative_prompt, list) and negative_prompt[0] is None: + negative_prompt = "" + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype) + negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers] + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + # Pad to longest + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + + if negative_prompt_embeds is not None: + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype + ) + max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1]) + + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers] + + negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding( + negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask + ) + negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers] + else: + max_tokens = prompt_embeds.shape[1] + prompt_embeds, prompt_attention_mask = self.pad_embedding( + prompt_embeds, max_tokens, attention_mask=prompt_attention_mask + ) + negative_prompt_layers = None + + dtype = self.text_encoder.dtype + text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype) + + return ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
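+    # Concretely, the denoising loop below applies classifier-free guidance as
+    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+    # whenever `guidance_scale > 1`; otherwise only the text-conditional prediction is computed.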
+ + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @staticmethod + # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _unpack_latents_no_patch(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels) + latents = latents.permute(0, 3, 1, 2) + + return latents + + @staticmethod + def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width): + latents = latents.permute(0, 2, 3, 1) + latents = latents.reshape(batch_size, height * width, num_channels_latents) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + do_patching=False, + ): + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if do_patching: + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + else: + latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + + return latents, latent_image_ids + + @staticmethod + def _prepare_attention_mask(attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + mask: Optional[PipelineMaskInput] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + seed: Optional[int] = None, + guidance_scale: float = 5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 3000, + do_patching=False, + _auto_resize: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or `torch.FloatTensor`, *optional*): + The image to guide the image generation. If not defined, the pipeline will generate an image from + scratch. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + seed (`int`, *optional*): + A seed used to make generation deterministic. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). 
`guidance_scale` is defined as `w` of equation 2
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely
+                linked to the text `prompt`, usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.BriaFiboPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to 3000): Maximum sequence length to use with the `prompt`.
+            do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
+        Examples:
+        Returns:
+            [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the + generated images. + """ + + if height is None or width is None: + if image is not None: + image_height, image_width = self.image_processor.get_default_height_width(image) + if _auto_resize: + image_width, image_height = min( + PREFERRED_RESOLUTION[1024 * 1024], + key=lambda size: abs(size[0] / size[1] - image_width / image_height), + ) + width, height = image_width, image_height + else: + raise ValueError("You must provide either an image or both height and width.") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + seed=seed, + image=image, + mask=mask, + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + if mask is not None and image is not None: + image = paste_mask_on_image(mask, image) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + + if prompt is not None and is_valid_edit_json(prompt): + prompt = json.dumps(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if generator is None and seed is not None: + generator = torch.Generator(device=device).manual_seed(seed) + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_layers, + negative_prompt_layers, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_images_per_prompt=num_images_per_prompt, + lora_scale=lora_scale, + ) + prompt_batch_size = prompt_embeds.shape[0] + + if guidance_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_layers = [ + torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers)) + ] + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + total_num_layers_transformer = len(self.transformer.transformer_blocks) + len( + self.transformer.single_transformer_blocks + ) + if len(prompt_layers) >= total_num_layers_transformer: + # remove first layers + prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :] + else: + # duplicate last layer + prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers)) + + # Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, height, width) + image = self.image_processor.preprocess(image, height, width) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + if do_patching: + num_channels_latents = int(num_channels_latents / 4) + + latents, latent_image_ids = self.prepare_latents( + prompt_batch_size, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + do_patching, + ) + + if image is not None: + image_latents, image_ids = self.prepare_image_latents( + image=image, + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + latent_image_ids = torch.cat([latent_image_ids, image_ids], dim=0) # dim 0 is sequence dimension + else: + image_latents = None + + latent_attention_mask = torch.ones( + [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device + ) + if guidance_scale > 1: + latent_attention_mask = latent_attention_mask.repeat(2, 1) + + if image_latents is None: + attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1) + else: + image_latent_attention_mask = torch.ones( + [image_latents.shape[0], image_latents.shape[1]], + dtype=image_latents.dtype, + device=image_latents.device, + ) + if guidance_scale > 1: + image_latent_attention_mask = image_latent_attention_mask.repeat(2, 1) + attention_mask = torch.cat( + [prompt_attention_mask, latent_attention_mask, image_latent_attention_mask], dim=1 + ) + + attention_mask = self.create_attention_matrix(attention_mask) # batch, seq => batch, seq, seq + attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting + + if self._joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + self._joint_attention_kwargs["attention_mask"] = attention_mask + + # Adapt scheduler to dynamic shifting (resolution dependent) + + if do_patching: + seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2)) + else: + seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor) + + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + + mu = calculate_shift( + seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + + # Init sigmas and timesteps according to shift size + # This changes the scheduler in-place according to the dynamic scheduling + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps=num_inference_steps, + device=device, + timesteps=None, + sigmas=sigmas, + mu=mu, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # Support old different diffusers versions + if len(latent_image_ids.shape) == 3: + latent_image_ids = latent_image_ids[0] + + if len(text_ids.shape) == 3: + text_ids = text_ids[0] + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = latents + + if image_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_latents], dim=1) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latent_model_input] * 2) if guidance_scale > 1 else latent_model_input + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to( + device=latent_model_input.device, dtype=latent_model_input.dtype + ) + + # This is predicts "v" from flow-matching or eps from diffusion + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_encoder_layers=prompt_layers, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + txt_ids=text_ids, + img_ids=latent_image_ids, + )[0] + + # perform guidance + if guidance_scale > 1: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred[:, : latents.shape[1], ...], t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + + else: + if do_patching: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + else: + latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor) + + latents = latents.unsqueeze(dim=2) + latents_device = latents[0].device + latents_dtype = latents[0].dtype + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents_device, latents_dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents_device, latents_dtype + ) + latents_scaled = [latent / latents_std + latents_mean for latent in latents] + latents_scaled = torch.cat(latents_scaled, dim=0) + image = [] + for scaled_latent in latents_scaled: + curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0] + curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type) + image.append(curr_image) + if len(image) == 1: + image = image[0] + else: + image = np.stack(image, axis=0) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return 
BriaFiboPipelineOutput(images=image) + + def prepare_image_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + image = image.to(device=device, dtype=dtype) + + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + # scaling + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) + + image_latents_cthw = self.vae.encode(image.unsqueeze(2)).latent_dist.mean + latents_scaled = [(latent - latents_mean) * latents_std for latent in image_latents_cthw] + image_latents_cthw = torch.concat(latents_scaled, dim=0) + image_latents_bchw = image_latents_cthw[:, :, 0, :, :] + + image_latent_height, image_latent_width = image_latents_bchw.shape[2:] + image_latents_bsd = self._pack_latents_no_patch( + latents=image_latents_bchw, + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=image_latent_height, + width=image_latent_width, + ) + # breakpoint() + image_ids = self._prepare_latent_image_ids( + batch_size=batch_size, height=image_latent_height, width=image_latent_width, device=device, dtype=dtype + ) + # image ids are the same as latent ids with the first dimension set to 1 instead of 0 + image_ids[..., 0] = 1 + return image_latents_bsd, image_ids + + def check_inputs( + self, + prompt, + seed, + image, + mask, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if seed is not None and not isinstance(seed, int): + raise ValueError("Seed must be an integer") + if image is not None and not isinstance(image, (torch.Tensor, Image.Image, list)): + raise ValueError("Image must be a valid image") + if image is None and mask is not None: + raise ValueError("If mask is provided, image must also be provided") + + if mask is not None and not is_valid_mask(mask): + raise ValueError("Mask must be a valid mask") + + if mask is not None and image is not None and not (get_mask_size(mask) == get_image_size(image)): + raise ValueError("Mask and image must have the same size") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and not is_valid_edit_json(prompt): + raise ValueError(f"`prompt` has to be a valid JSON string or dict but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 3000: + raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}") + + def create_attention_matrix(self, attention_mask): + attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask) + + # convert to 0 - keep, -inf ignore + attention_matrix = torch.where( + attention_matrix == 1, 0.0, -torch.inf + ) # Apply -inf to ignored tokens for nulling softmax score + return attention_matrix diff --git a/src/diffusers/pipelines/chroma/__init__.py b/src/diffusers/pipelines/chroma/__init__.py index d9238b735c41..25069b5543c1 100644 --- a/src/diffusers/pipelines/chroma/__init__.py +++ b/src/diffusers/pipelines/chroma/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_chroma"] = ["ChromaPipeline"] _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"] + _import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -33,6 +34,7 @@ else: from .pipeline_chroma import ChromaPipeline from .pipeline_chroma_img2img import ChromaImg2ImgPipeline + from .pipeline_chroma_inpainting import ChromaInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py new file mode 100644 index 000000000000..3ea1ece36c87 --- /dev/null +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py @@ -0,0 +1,1184 @@ +""" +ChromaInpaintPipeline implements a text-guided image inpainting pipeline for the lodestones/Chroma1-HD model, based on +the ChromaPipeline from Hugging Face Diffusers and the Stable Diffusion inpainting +approach. +""" + +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
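The "Stable Diffusion inpainting approach" named in the module docstring above comes down to one operation repeated at every denoising step: re-noise the latents of the original image to the current noise level, then composite them with the model's prediction so that only the masked region is actually rewritten (the same blend appears further down in this file as `latents = (1 - init_mask) * init_latents_proper + init_mask * latents`). A minimal, standalone sketch of that blend is shown below; it uses plain NumPy, the helper name is illustrative, and the flow-matching noising rule `x_t = (1 - sigma) * x_0 + sigma * noise` is an assumption for illustration rather than the scheduler's own API.

```py
import numpy as np


def blend_known_region(denoised_latents, image_latents, noise, mask, sigma):
    """Keep the unmasked area of `image_latents`, the model output elsewhere.

    mask == 1 marks the region to inpaint, mask == 0 the region to preserve.
    `sigma` is the current noise level (flow-matching convention assumed:
    x_t = (1 - sigma) * x_0 + sigma * noise).
    """
    # Re-noise the original image latents so they match the noise level of the
    # partially denoised latents produced by the sampler at this step.
    renoised_original = (1.0 - sigma) * image_latents + sigma * noise
    # Composite: original content where mask == 0, model prediction where mask == 1.
    return (1.0 - mask) * renoised_original + mask * denoised_latents


# Toy shapes: batch of 1, 4 latent channels, an 8x8 latent grid.
rng = np.random.default_rng(0)
image_latents = rng.standard_normal((1, 4, 8, 8)).astype(np.float32)
noise = rng.standard_normal((1, 4, 8, 8)).astype(np.float32)
denoised = rng.standard_normal((1, 4, 8, 8)).astype(np.float32)
mask = np.zeros((1, 1, 8, 8), dtype=np.float32)
mask[..., 2:6, 2:6] = 1.0  # inpaint only the central square
blended = blend_known_region(denoised, image_latents, noise, mask, sigma=0.3)
assert blended.shape == denoised.shape
```

Blending at every step, rather than pasting the original pixels back once at the end, keeps the preserved and inpainted regions statistically consistent at each noise level, which is why the pipeline applies the composite inside its denoising loop.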
+import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import ( + CLIPImageProcessor, + CLIPVisionModelWithProjection, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ChromaTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..chroma.pipeline_output import ChromaPipelineOutput +from ..pipeline_utils import DiffusionPipeline + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ChromaInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ChromaInpaintPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("chroma_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
+ + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ChromaInpaintPipeline( + DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, FluxIPAdapterMixin +): + r""" + The Flux pipeline for image inpainting. + + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`ChromaTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: ChromaTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 + + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.default_sample_size = 128 + + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str], None] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + tokenizer_mask = text_inputs.attention_mask + + tokenizer_mask = tokenizer_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask, + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + seq_lengths = tokenizer_mask.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = 
attention_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str], None] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + do_classifier_free_guidance: bool = True, + max_sequence_length: int = 256, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device, dtype=dtype) + negative_text_ids = None + + if do_classifier_free_guidance: + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3, device=device, dtype=dtype) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." 
+ ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + max_sequence_length=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError( + "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask" + ) + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + + if masked_image.shape[1] == 16: + masked_image_latents = masked_image + else: + masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator) + + masked_image_latents = ( + masked_image_latents - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents( + masked_image_latents, + batch_size, + num_channels_latents, + height, + width, + ) + mask = self._pack_latents( + mask.repeat(1, num_channels_latents, 1, 1), + batch_size, + num_channels_latents, + height, + width, + ) + + return mask, masked_image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def _prepare_attention_mask( + self, + batch_size, + sequence_length, + dtype, + attention_mask=None, + ): + if attention_mask is None: + return attention_mask + + # Extend the prompt attention mask to account for image tokens in the final sequence + attention_mask = torch.cat( + [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)], + dim=1, + ) + attention_mask = attention_mask.to(dtype) + + return attention_mask + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 1.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). 
height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + strength (`float`, *optional*, defaults to 0.6): + Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. `image` will + be used as a starting point, adding more noise to it the larger the strength. The number of denoising + steps depends on the amount of noise initially added. When strength is 1, added noise will be maximum + and the denoising process will run for the full number of iterations specified in num_inference_steps. + A value of 1, therefore, essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence. + Chroma requires a single padding token remain unmasked. Please refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + negative_prompt_attention_mask (torch.Tensor, *optional*): + Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative + prompt sequence. Chroma requires a single padding token remain unmasked. Please refer to + https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.chroma.ChromaPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs.
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + output_type=output_type, + strength=strength, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + image=image, + mask_image=mask_image, + padding_mask_crop=padding_mask_crop, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + do_classifier_free_guidance=self.do_classifier_free_guidance, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_transformer = self.transformer.config.in_channels + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=prompt_attention_mask, + ) + negative_attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=negative_prompt_attention_mask, + ) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + attention_mask=attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + attention_mask=negative_attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # for 64 channel transformer only. + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ChromaPipelineOutput(images=image) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 4ac33b24bbe1..245c794c9c93 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -664,7 +664,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. 
Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) self._num_timesteps = len(timesteps) # 5. Prepare latents diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index c1335839f848..456f0bda1644 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -717,7 +717,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) self._num_timesteps = len(timesteps) # 5. Prepare latents diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index c523c9adec98..321f0f073fe7 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -762,7 +762,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) self._num_timesteps = len(timesteps) # 5. Prepare latents diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 897dc6d1b70a..e27c572020d6 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -737,7 +737,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py index 304a5c5ad00b..46f60d24a467 100644 --- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py +++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py @@ -566,7 +566,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. 
Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) self._num_timesteps = len(timesteps) # 5. Prepare latents. diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 22510f5d9d50..9a2d555538d5 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -599,8 +599,12 @@ def __call__( self.scheduler.config.get("base_shift", 0.25), self.scheduler.config.get("max_shift", 0.75), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py index e26b7ba415de..2d6785f791db 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py @@ -649,8 +649,12 @@ def __call__( self.scheduler.config.get("base_shift", 0.25), self.scheduler.config.get("max_shift", 0.75), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas, mu=mu ) self._num_timesteps = len(timesteps) # Denoising loop diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index fe0e69314cca..e2fb32688392 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1195,8 +1195,12 @@ def __call__( assert False # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0e2a1441f8f6..283c3f92390c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1344,8 +1344,12 @@ def __call__( assert False # 5. 
Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 94c4c394465b..2ea7307fec32 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -84,7 +84,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index 40cc76cf70d8..99f2958b320e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -1339,8 +1339,12 @@ def __call__( height, width = control_image[0][0].shape[-2:] # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index 2b5684de9511..29a7d6147638 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -185,7 +185,7 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. 
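Note: the recurring change in the CogVideoX/CogView/ControlNet hunks above (and in the SD3 ControlNet hunks further below) is one workaround: when `torch_xla` is available, scheduler timesteps are retrieved on the CPU instead of on the accelerator device, apparently so the scheduler's step bookkeeping stays on the host. A minimal sketch of that selection logic follows; the helper name `timestep_device` is only for illustration (the patch inlines the equivalent if/else in each pipeline), and `retrieve_timesteps` / `XLA_AVAILABLE` refer to the module-level helper and flag already present in each touched file.

```python
import torch


def timestep_device(device: torch.device, xla_available: bool):
    # Sketch of the device selection repeated throughout this patch: under
    # torch_xla, scheduler timesteps are built on the CPU; otherwise they stay
    # on the pipeline's execution device.
    return "cpu" if xla_available else device


# Inside a pipeline's __call__ this would feed retrieve_timesteps, e.g. (sketch):
# timesteps, num_inference_steps = retrieve_timesteps(
#     self.scheduler, num_inference_steps, timestep_device(device, XLA_AVAILABLE), timesteps, sigmas
# )
```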
@@ -229,7 +229,7 @@ def __init__( HunyuanDiT2DMultiControlNetModel, ], text_encoder_2: Optional[T5EncoderModel] = None, - tokenizer_2: Optional[MT5Tokenizer] = None, + tokenizer_2: Optional[T5Tokenizer] = None, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index d605eac1f2b1..d721acc77c2a 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -1098,7 +1098,13 @@ def __call__( assert False # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 9d0158c6b654..2071305cdf10 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -15,6 +15,8 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image import torch from transformers import ( CLIPTextModelWithProjection, @@ -39,7 +41,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput @@ -227,6 +229,8 @@ def __init__( feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() + if isinstance(controlnet, (list, tuple)): + controlnet = SD3MultiControlNetModel(controlnet) self.register_modules( vae=vae, @@ -572,14 +576,52 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.check_inputs + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is 
{type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + def check_inputs( self, + height, + width, + image, prompt, prompt_2, prompt_3, - height, - width, negative_prompt=None, negative_prompt_2=None, negative_prompt_3=None, @@ -587,6 +629,11 @@ def check_inputs( negative_prompt_embeds=None, pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -669,6 +716,76 @@ def check_inputs( if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, SD3MultiControlNetModel): + if isinstance(prompt, list) and len(prompt) > 1: + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, SD3ControlNetModel): + self.check_image(image, prompt, prompt_embeds) + elif isinstance(controlnet, SD3MultiControlNetModel): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + + # Check `controlnet_conditioning_scale` + if isinstance(controlnet, SD3MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(controlnet, SD3MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." 
+ ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents def prepare_latents( self, @@ -1040,11 +1157,12 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( + height, + width, + control_image, prompt, prompt_2, prompt_3, - height, - width, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, negative_prompt_3=negative_prompt_3, @@ -1052,6 +1170,11 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -1119,9 +1242,26 @@ def __call__( width = latent_width * self.vae_scale_factor elif isinstance(self.controlnet, SD3MultiControlNetModel): - raise NotImplementedError("MultiControlNetModel is not supported for SD3ControlNetInpaintingPipeline.") + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_image_with_mask( + image=control_image_, + mask=control_mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + control_images.append(control_image_) + + control_image = control_images else: - assert False + raise ValueError("Controlnet not found. Please check the controlnet model.") if controlnet_pooled_projections is None: controlnet_pooled_projections = torch.zeros_like(pooled_prompt_embeds) @@ -1129,7 +1269,13 @@ def __call__( controlnet_pooled_projections = controlnet_pooled_projections or pooled_prompt_embeds # 4.
Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py index 2833c89abd5e..944f16553173 100644 --- a/src/diffusers/pipelines/cosmos/__init__.py +++ b/src/diffusers/pipelines/cosmos/__init__.py @@ -22,6 +22,9 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["pipeline_cosmos2_5_predict"] = [ + "Cosmos2_5_PredictBasePipeline", + ] _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"] _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"] _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"] @@ -35,6 +38,9 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * else: + from .pipeline_cosmos2_5_predict import ( + Cosmos2_5_PredictBasePipeline, + ) from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py new file mode 100644 index 000000000000..0f3f62551d35 --- /dev/null +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py @@ -0,0 +1,880 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torchvision +import torchvision.transforms +import torchvision.transforms.functional +from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan, CosmosTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import CosmosPipelineOutput + + +if is_cosmos_guardrail_available(): + from cosmos_guardrail import CosmosSafetyChecker +else: + + class CosmosSafetyChecker: + def __init__(self, *args, **kwargs): + raise ImportError( + "`cosmos_guardrail` is not installed. 
Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`." + ) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import Cosmos2_5_PredictBasePipeline + >>> from diffusers.utils import export_to_video, load_image, load_video + + >>> model_id = "nvidia/Cosmos-Predict2.5-2B" + >>> pipe = Cosmos2_5_PredictBasePipeline.from_pretrained( + ... model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Common negative prompt reused across modes. + >>> negative_prompt = ( + ... "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + ... "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + ... "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " + ... "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + ... "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + ... "Overall, the video is of poor quality." + ... ) + + >>> # Text2World: generate a 93-frame world video from text only. + >>> prompt = ( + ... "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights " + ... "cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh " + ... "lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet " + ... "reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. " + ... "The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow " + ... "advance of traffic through the frosty city corridor." + ... ) + >>> video = pipe( + ... image=None, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... 
generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "text2world.mp4", fps=16) + + >>> # Image2World: condition on a single image and generate a 93-frame world video. + >>> prompt = ( + ... "A high-definition video captures the precision of robotic welding in an industrial setting. " + ... "The first frame showcases a robotic arm, equipped with a welding torch, positioned over a large metal structure. " + ... "The welding process is in full swing, with bright sparks and intense light illuminating the scene, creating a vivid " + ... "display of blue and white hues. A significant amount of smoke billows around the welding area, partially obscuring " + ... "the view but emphasizing the heat and activity. The background reveals parts of the workshop environment, including a " + ... "ventilation system and various pieces of machinery, indicating a busy and functional industrial workspace. As the video " + ... "progresses, the robotic arm maintains its steady position, continuing the welding process and moving to its left. " + ... "The welding torch consistently emits sparks and light, and the smoke continues to rise, diffusing slightly as it moves upward. " + ... "The metal surface beneath the torch shows ongoing signs of heating and melting. The scene retains its industrial ambiance, with " + ... "the welding sparks and smoke dominating the visual field, underscoring the ongoing nature of the welding operation." + ... ) + >>> image = load_image( + ... "https://media.githubusercontent.com/media/nvidia-cosmos/cosmos-predict2.5/refs/heads/main/assets/base/robot_welding.jpg" + ... ) + >>> video = pipe( + ... image=image, + ... video=None, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "image2world.mp4", fps=16) + + >>> # Video2World: condition on an input clip and predict a 93-frame world video. + >>> prompt = ( + ... "The video opens with an aerial view of a large-scale sand mining construction operation, showcasing extensive piles " + ... "of brown sand meticulously arranged in parallel rows. A central water channel, fed by a water pipe, flows through the " + ... "middle of these sand heaps, creating ripples and movement as it cascades down. The surrounding area features dense green " + ... "vegetation on the left, contrasting with the sandy terrain, while a body of water is visible in the background on the right. " + ... "As the video progresses, a piece of heavy machinery, likely a bulldozer, enters the frame from the right, moving slowly along " + ... "the edge of the sand piles. This machinery's presence indicates ongoing construction work in the operation. The final frame " + ... "captures the same scene, with the water continuing its flow and the bulldozer still in motion, maintaining the dynamic yet " + ... "steady pace of the construction activity." + ... ) + >>> input_video = load_video( + ... "https://github.com/nvidia-cosmos/cosmos-predict2.5/raw/refs/heads/main/assets/base/sand_mining.mp4" + ... ) + >>> video = pipe( + ... image=None, + ... video=input_video, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... num_frames=93, + ... generator=torch.Generator().manual_seed(1), + ... ).frames[0] + >>> export_to_video(video, "video2world.mp4", fps=16) + + >>> # To produce an image instead of a world (video) clip, set num_frames=1 and + >>> # save the first frame: pipe(..., num_frames=1).frames[0][0]. 
+ ``` +""" + + +class Cosmos2_5_PredictBasePipeline(DiffusionPipeline): + r""" + Pipeline for [Cosmos Predict2.5](https://github.com/nvidia-cosmos/cosmos-predict2.5) base model. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder. Cosmos Predict2.5 uses the [Qwen2.5 + VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) encoder. + tokenizer (`AutoTokenizer`): + Tokenizer associated with the Qwen2.5 VL encoder. + transformer ([`CosmosTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`UniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + # We mark safety_checker as optional here to get around some test failures, but it is not really optional + _optional_components = ["safety_checker"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: AutoTokenizer, + transformer: CosmosTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + safety_checker: CosmosSafetyChecker = None, + ): + super().__init__() + + if safety_checker is None: + safety_checker = CosmosSafetyChecker() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + safety_checker=safety_checker, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_mean", None) is not None + else None + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float() + if getattr(self.vae.config, "latents_std", None) is not None + else None + ) + self.latents_mean = latents_mean + self.latents_std = latents_std + + if self.latents_mean is None or self.latents_std is None: + raise ValueError("VAE configuration must define both `latents_mean` and `latents_std`.") + + def _get_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + + input_ids_batch = [] + + for sample_idx in range(len(prompt)): + conversations = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful assistant who will provide prompts to an image generator.", + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt[sample_idx], + } + ], + }, + ] + input_ids = 
self.tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=False, + add_vision_id=False, + max_length=max_sequence_length, + truncation=True, + padding="max_length", + ) + input_ids = torch.LongTensor(input_ids) + input_ids_batch.append(input_ids) + + input_ids_batch = torch.stack(input_ids_batch, dim=0) + + outputs = self.text_encoder( + input_ids_batch.to(device), + output_hidden_states=True, + ) + hidden_states = outputs.hidden_states + + normalized_hidden_states = [] + for layer_idx in range(1, len(hidden_states)): + normalized_state = (hidden_states[layer_idx] - hidden_states[layer_idx].mean(dim=-1, keepdim=True)) / ( + hidden_states[layer_idx].std(dim=-1, keepdim=True) + 1e-8 + ) + normalized_hidden_states.append(normalized_state) + + prompt_embeds = torch.cat(normalized_hidden_states, dim=-1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_prompt_embeds( + prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_prompt_embeds( + prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + # Modified from diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2VideoToWorldPipeline.prepare_latents and + # diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.Cosmos2TextToImagePipeline.prepare_latents + def prepare_latents( + self, + video: Optional[torch.Tensor], + batch_size: int, + num_channels_latents: int = 16, + height: int = 704, + width: int = 1280, + num_frames_in: int = 93, + num_frames_out: int = 93, + do_classifier_free_guidance: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + B = batch_size + C = num_channels_latents + T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1 + H = height // self.vae_scale_factor_spatial + W = width // self.vae_scale_factor_spatial + shape = (B, C, T, H, W) + + if num_frames_in == 0: + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) + cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) + + cond_latents = torch.zeros_like(latents) + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + else: + if video is None: + raise ValueError("`video` must be provided when `num_frames_in` is greater than 0.") + needs_preprocessing = not (isinstance(video, torch.Tensor) and video.ndim == 5 and video.shape[1] == 3) + if needs_preprocessing: + video = self.video_processor.preprocess_video(video, height, width) + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + cond_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i]) + for i in range(batch_size) + ] + else: + cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + cond_latents = torch.cat(cond_latents, dim=0).to(dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=dtype) + latents_std = self.latents_std.to(device=device, dtype=dtype) + cond_latents = (cond_latents - latents_mean) / latents_std + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + padding_shape = (B, 1, T, H, W) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + num_cond_latent_frames = (num_frames_in - 1) // self.vae_scale_factor_temporal + 1 + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 0:num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding + + return ( + latents, + cond_latents, + cond_mask, + cond_indicator, + ) + + # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput | None = None, + video: List[PipelineImageInput] | None = None, + prompt: Union[str, List[str]] | None = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 704, + width: int = 1280, + num_frames: int = 93, + num_inference_steps: int = 36, + guidance_scale: float = 7.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + conditional_frame_timestep: float = 0.1, + num_latent_conditional_frames: int = 2, + ): + r""" + The call function to the pipeline for generation. Supports three modes: + + - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip. + - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame. + - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip. + + Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the + above in "*2Image mode"). + + Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt). + + Args: + image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional single image for Image2World conditioning. Must be `None` when `video` is provided. + video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*): + Optional input video for Video2World conditioning. Must be `None` when `image` is provided. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied. + height (`int`, defaults to `704`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `93`): + Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame. + num_inference_steps (`int`, defaults to `35`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `7.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). 
Guidance scale is enabled by setting + `guidance_scale > 1`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `512`): + The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If + the prompt is shorter than this length, it will be padded. + num_latent_conditional_frames (`int`, defaults to `2`): + Number of latent conditional frames to use for Video2World conditioning. The number of pixel frames + extracted from the input video is calculated as `4 * (num_latent_conditional_frames - 1) + 1`. Set to 1 + for Image2World-like behavior (single frame conditioning). + + Examples: + + Returns: + [`~CosmosPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if self.safety_checker is None: + raise ValueError( + f"You have disabled the safety checker for {self.__class__}. This is in violation of the " + "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). " + f"Please ensure that you are compliant with the license agreement." 
+ ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) + + self._guidance_scale = guidance_scale + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + if self.safety_checker is not None: + self.safety_checker.to(device) + if prompt is not None: + prompt_list = [prompt] if isinstance(prompt, str) else prompt + for p in prompt_list: + if not self.safety_checker.check_text_safety(p): + raise ValueError( + f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the " + f"prompt abides by the NVIDIA Open Model License Agreement." + ) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + vae_dtype = self.vae.dtype + transformer_dtype = self.transformer.dtype + + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + elif video is None: + video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + else: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") + + if num_latent_conditional_frames not in [1, 2]: + raise ValueError( + f"num_latent_conditional_frames must be 1 or 2, but got {num_latent_conditional_frames}" + ) + + frames_to_extract = 4 * (num_latent_conditional_frames - 1) + 1 + + total_input_frames = len(video) + + if total_input_frames < frames_to_extract: + raise ValueError( + f"Input video has only {total_input_frames} frames but Video2World requires at least " + f"{frames_to_extract} frames for conditioning." 
+ ) + + num_frames_in = frames_to_extract + + assert video is not None + video = self.video_processor.preprocess_video(video, height, width) + + # For Video2World: extract last frames_to_extract frames from input, then pad + if image is None and num_frames_in > 0 and num_frames_in < video.shape[2]: + video = video[:, :, -num_frames_in:, :, :] + + num_frames_out = num_frames + + if video.shape[2] < num_frames_out: + n_pad_frames = num_frames_out - video.shape[2] + last_frame = video[:, :, -1:, :, :] # [B, C, T==1, H, W] + pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W] + video = torch.cat((video, pad_frames), dim=2) + + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = video.to(device=device, dtype=vae_dtype) + + num_channels_latents = self.transformer.config.in_channels - 1 + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames_in=num_frames_in, + num_frames_out=num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + cond_mask = cond_mask.to(transformer_dtype) + + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + gt_velocity = (latents - cond_latent) * cond_mask + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) + ) + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t + noise_pred = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + noise_pred_neg = self.transformer( + hidden_states=in_latents, + condition_mask=cond_mask, + timestep=in_timestep, + encoder_hidden_states=negative_prompt_embeds, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + 
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents_mean = self.latents_mean.to(latents.device, latents.dtype) + latents_std = self.latents_std.to(latents.device, latents.dtype) + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] + video = self._match_num_frames(video, num_frames) + + assert self.safety_checker is not None + self.safety_checker.to(device) + video = self.video_processor.postprocess_video(video, output_type="np") + video = (video * 255).astype(np.uint8) + video_batch = [] + for vid in video: + vid = self.safety_checker.check_video_safety(vid) + video_batch.append(vid) + video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1 + video = torch.from_numpy(video).permute(0, 4, 1, 2, 3) + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CosmosPipelineOutput(frames=video) + + def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: + if target_num_frames <= 0 or video.shape[2] == target_num_frames: + return video + + frames_per_latent = max(self.vae_scale_factor_temporal, 1) + video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) + + current_frames = video.shape[2] + if current_frames < target_num_frames: + pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) + video = torch.cat([video, pad], dim=2) + elif current_frames > target_num_frames: + video = video[:, :, :target_num_frames] + + return video diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py index 66490c2be159..d5ee60962285 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py @@ -49,6 +49,14 @@ def __init__(self, *args, **kwargs): logger = logging.get_logger(__name__) # pylint: disable=invalid-name +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." 
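+    # NOTE (editorial): this default is only applied when the caller passes `negative_prompt=None`
+    # and no `negative_prompt_embeds`; passing `negative_prompt=""` opts out of it (see the
+    # `encode_prompt` change below).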
+) EXAMPLE_DOC_STRING = """ Examples: @@ -300,7 +308,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py index 23a74ad00f93..b2c44d0d9972 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py @@ -50,6 +50,14 @@ def __init__(self, *args, **kwargs): logger = logging.get_logger(__name__) # pylint: disable=invalid-name +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) EXAMPLE_DOC_STRING = """ Examples: @@ -319,7 +327,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py index f0aa1ecf0e0f..676b7c7a72b0 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py @@ -49,6 +49,14 @@ def __init__(self, *args, **kwargs): logger = logging.get_logger(__name__) # pylint: disable=invalid-name +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." 
+) EXAMPLE_DOC_STRING = """ Examples: @@ -285,7 +293,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py index cd5a734cc311..df507c3a4f90 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py @@ -50,6 +50,14 @@ def __init__(self, *args, **kwargs): logger = logging.get_logger(__name__) # pylint: disable=invalid-name +DEFAULT_NEGATIVE_PROMPT = ( + "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " + "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " + "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, " + "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " + "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " + "Overall, the video is of poor quality." +) EXAMPLE_DOC_STRING = """ Examples: @@ -331,7 +339,7 @@ def encode_prompt( prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt = negative_prompt if negative_prompt is not None else DEFAULT_NEGATIVE_PROMPT negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py index 92239c0d32f0..86c4d6812130 100755 --- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py +++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py @@ -666,12 +666,18 @@ def __call__( ) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, mu=1 + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 ) else: - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py index f74a11f87d75..b28a2c9fb273 100755 --- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py @@ -810,12 +810,18 @@ def __call__( ) # 4. 
Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, mu=1 + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 ) else: - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) timesteps = self.scheduler.timesteps # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py index b16ef92d8e6b..ec394315ee93 100755 --- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py @@ -956,12 +956,18 @@ def __call__( ) # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, mu=1 + self.scheduler, num_inference_steps, timestep_device, timesteps, mu=1 ) else: - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 5041e352f73d..9562722dbee3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -876,10 +876,15 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index 848d7bd39254..77f971d57a80 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -829,10 +829,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index 262345c75afc..e1bbc6735051 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -810,10 +810,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, 
num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index 6915a83a7ca7..b02e74d3b2d6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -1013,10 +1013,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 507ec687347c..78de4f617f84 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -1002,10 +1002,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 582c7bbad84e..5bf593258f49 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -873,10 +873,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index f7f34ef231af..a1e1f5f5e9e5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1020,10 +1020,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 5cb9c82204b2..8ec9871d2579 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -932,10 +932,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py 
b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index ab9140dae921..5166a6497e01 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -940,10 +940,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 3bfe82cf4382..64a81fb0699f 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -1015,10 +1015,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py index d986c9a63011..f6e1d5206630 100644 --- a/src/diffusers/pipelines/flux2/__init__.py +++ b/src/diffusers/pipelines/flux2/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] + _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -31,6 +32,7 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_flux2 import Flux2Pipeline + from .pipeline_flux2_klein import Flux2KleinPipeline else: import sys diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index b54a43dd89a5..c01b7137e086 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -725,8 +725,8 @@ def guidance_scale(self): return self._guidance_scale @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs + def attention_kwargs(self): + return self._attention_kwargs @property def num_timesteps(self): @@ -975,7 +975,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, # B, text_seq_len, 4 img_ids=latent_image_ids, # B, image_seq_len, 4 - joint_attention_kwargs=self._attention_kwargs, + joint_attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py new file mode 100644 index 000000000000..efb0aebf8593 --- /dev/null +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py @@ -0,0 +1,918 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2KleinPipeline + + >>> pipe = Flux2KleinPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=4.0).images[0] + >>> image.save("flux2_output.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + The Flux2 Klein pipeline for text-to-image generation. + + Reference: + [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3ForCausalLM`]): + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM) + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). 
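+        is_distilled (`bool`, *optional*, defaults to `False`):
+            Whether the checkpoint is a step-wise (guidance) distilled variant. When `True`,
+            classifier-free guidance is disabled and `guidance_scale > 1` is ignored.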
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + transformer: Flux2Transformer2DModel, + is_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + + self.register_to_config(is_distilled=is_distilled) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + @staticmethod + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. 
+ + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. 
Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + 
max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self.config.is_distilled + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[Union[str, List[str]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (9, 18, 27), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. 
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + guidance_scale=guidance_scale, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1) :] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + if output_type == "latent": + image = latents + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py new file mode 100644 index 000000000000..140b9cc760cc --- /dev/null +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]} + +# Import transformers components so they can be resolved during pipeline loading + +if is_transformers_available() and is_transformers_version(">=", "4.57.4"): + try: + from transformers import GlmImageForConditionalGeneration, GlmImageProcessor + + _additional_imports["GlmImageForConditionalGeneration"] = GlmImageForConditionalGeneration + _additional_imports["GlmImageProcessor"] = GlmImageProcessor + except ImportError: + pass + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_glm_image import GlmImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff 
--git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py new file mode 100644 index 000000000000..589b3be47b2c --- /dev/null +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -0,0 +1,1051 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import ByT5Tokenizer, PreTrainedModel, ProcessorMixin, T5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, GlmImageTransformer2DModel +from ...models.transformers.transformer_glm_image import GlmImageKVCache +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, is_transformers_version, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import GlmImagePipelineOutput + + +# Because it's not released in stable as of 13/01/2026. So this is just a proxy. +GlmImageProcessor = ProcessorMixin +GlmImageForConditionalGeneration = PreTrainedModel +if is_transformers_version(">=", "5.0.0.dev0"): + from transformers import GlmImageForConditionalGeneration, GlmImageProcessor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import GlmImagePipeline + + >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using GLM-Image. + + This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion + transformer) model for image decoding. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer for the text encoder. 
+ processor (`AutoProcessor`): + Processor for the AR model to handle chat templates and tokenization. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder for glyph embeddings. + vision_language_encoder ([`GlmImageForConditionalGeneration`]): + The AR model that generates image tokens from text prompts. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`GlmImageTransformer2DModel`]): + A text conditioned transformer to denoise the encoded image latents (DiT). + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: ByT5Tokenizer, + processor: GlmImageProcessor, + text_encoder: T5EncoderModel, + vision_language_encoder: GlmImageForConditionalGeneration, + vae: AutoencoderKL, + transformer: GlmImageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + processor=processor, + text_encoder=text_encoder, + vision_language_encoder=vision_language_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") + and self.transformer is not None + and hasattr(self.transformer.config, "sample_size") + else 128 + ) + + @staticmethod + def _compute_generation_params( + image_grid_thw, + is_text_to_image: bool, + ): + grid_sizes = [] + grid_hw = [] + + for i in range(image_grid_thw.shape[0]): + t, h, w = image_grid_thw[i].tolist() + grid_sizes.append(int(h * w)) + grid_hw.append((int(h), int(w))) + + if not is_text_to_image: + max_new_tokens = grid_sizes[-1] + 1 + large_image_start_offset = 0 + target_grid_h, target_grid_w = grid_hw[-1] + else: + total_tokens = sum(grid_sizes) + max_new_tokens = total_tokens + 1 + large_image_start_offset = sum(grid_sizes[1:]) + target_grid_h, target_grid_w = grid_hw[0] + return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w + + @staticmethod + def _extract_large_image_tokens( + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + token_ids = token_ids.view(1, 1, token_h, token_w) + token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( + dtype=torch.long + ) + token_ids = token_ids.view(1, -1) + return token_ids + + @staticmethod + def _validate_and_normalize_images( + image: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]], + batch_size: int, + ) -> Optional[List[List[PIL.Image.Image]]]: + """ + Validate and normalize image inputs to List[List[PIL.Image]]. 
+ + Rules: + - batch_size > 1: Only accepts List[List[PIL.Image]], each sublist must have equal length + - batch_size == 1: Accepts List[PIL.Image] for legacy compatibility (converted to [[img1, img2, ...]]) + - Other formats raise ValueError + + Args: + image: Input images in various formats + batch_size: Number of prompts in the batch + + Returns: + Normalized images as List[List[PIL.Image]], or None if no images provided + """ + if image is None or len(image) == 0: + return None + + first_element = image[0] + + if batch_size == 1: + # Legacy format: List[PIL.Image] -> [[img1, img2, ...]] + if not isinstance(first_element, (list, tuple)): + return [list(image)] + # Already in List[List[PIL.Image]] format + if len(image) != 1: + raise ValueError( + f"For batch_size=1 with List[List[PIL.Image]] format, expected 1 image list, got {len(image)}." + ) + return [list(image[0])] + + # batch_size > 1: must be List[List[PIL.Image]] + if not isinstance(first_element, (list, tuple)): + raise ValueError( + f"For batch_size > 1, images must be List[List[PIL.Image]] format. " + f"Got List[{type(first_element).__name__}] instead. " + f"Each prompt requires its own list of condition images." + ) + + if len(image) != batch_size: + raise ValueError(f"Number of image lists ({len(image)}) must match batch size ({batch_size}).") + + # Validate homogeneous: all sublists must have same length + num_input_images_per_prompt = len(image[0]) + for idx, imgs in enumerate(image): + if len(imgs) != num_input_images_per_prompt: + raise ValueError( + f"All prompts must have the same number of condition images. " + f"Prompt 0 has {num_input_images_per_prompt} images, but prompt {idx} has {len(imgs)} images." + ) + + return [list(imgs) for imgs in image] + + def generate_prior_tokens( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + image: Optional[List[List[PIL.Image.Image]]] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + ): + """ + Generate prior tokens for the DiT model using the AR model. + + Args: + prompt: Single prompt or list of prompts + height: Target image height + width: Target image width + image: Normalized image input as List[List[PIL.Image]]. Should be pre-validated + using _validate_and_normalize_images() before calling this method. + device: Target device + generator: Random generator for reproducibility + + Returns: + Tuple of: + - prior_token_ids: Tensor of shape (batch_size, num_tokens) with upsampled prior tokens + - prior_token_image_ids_per_sample: List of tensors, one per sample. Each tensor contains + the upsampled prior token ids for all condition images in that sample. None for t2i. + - source_image_grid_thw_per_sample: List of tensors, one per sample. Each tensor has shape + (num_condition_images, 3) with upsampled grid info. None for t2i. 
+ """ + device = device or self._execution_device + + # Normalize prompt to list format + prompt_list = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt_list) + + # Image is already normalized by _validate_and_normalize_images(): None or List[List[PIL.Image]] + is_text_to_image = image is None + # Build messages for each sample in the batch + all_messages = [] + for idx, p in enumerate(prompt_list): + content = [] + if not is_text_to_image: + for img in image[idx]: + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": p}) + all_messages.append([{"role": "user", "content": content}]) + # Process with the processor (supports batch with left padding) + inputs = self.processor.apply_chat_template( + all_messages, + tokenize=True, + padding=True if batch_size > 1 else False, + target_h=height, + target_w=width, + return_dict=True, + return_tensors="pt", + ).to(device) + + image_grid_thw = inputs.get("image_grid_thw") + images_per_sample = inputs.get("images_per_sample") + + # Determine number of condition images and grids per sample + num_condition_images = 0 if is_text_to_image else len(image[0]) + if images_per_sample is not None: + num_grids_per_sample = images_per_sample[0].item() + else: + # Fallback for batch_size=1: total grids is for single sample + num_grids_per_sample = image_grid_thw.shape[0] + + # Compute generation params (same for all samples in homogeneous batch) + first_sample_grids = image_grid_thw[:num_grids_per_sample] + max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params( + image_grid_thw=first_sample_grids, is_text_to_image=is_text_to_image + ) + + # Generate source image tokens (prior_token_image_ids) for i2i mode + prior_token_image_ids = None + source_image_grid_thw = None + if not is_text_to_image: + # Extract source grids by selecting condition image indices (skip target grids) + # Grid order from processor: [s0_cond1, s0_cond2, ..., s0_target, s1_cond1, s1_cond2, ..., s1_target, ...] + # We need indices: [0, 1, ..., num_condition_images-1, num_grids_per_sample, num_grids_per_sample+1, ...] 
+ source_indices = [] + for sample_idx in range(batch_size): + base = sample_idx * num_grids_per_sample + source_indices.extend(range(base, base + num_condition_images)) + source_grids = image_grid_thw[source_indices] + + if len(source_grids) > 0: + prior_token_image_embed = self.vision_language_encoder.get_image_features( + inputs["pixel_values"], source_grids + ).pooler_output + prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) + prior_token_image_ids_d32 = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, source_grids + ) + # Upsample each source image's prior tokens to match VAE/DiT resolution + split_sizes = source_grids.prod(dim=-1).tolist() + prior_ids_per_source = torch.split(prior_token_image_ids_d32, split_sizes) + upsampled_prior_ids = [] + for i, prior_ids in enumerate(prior_ids_per_source): + t, h, w = source_grids[i].tolist() + upsampled = self._upsample_token_ids(prior_ids, int(h), int(w)) + upsampled_prior_ids.append(upsampled.squeeze(0)) + prior_token_image_ids = torch.cat(upsampled_prior_ids, dim=0) + # Upsample grid dimensions for later splitting + upsampled_grids = source_grids.clone() + upsampled_grids[:, 1] = upsampled_grids[:, 1] * 2 + upsampled_grids[:, 2] = upsampled_grids[:, 2] * 2 + source_image_grid_thw = upsampled_grids + + # Generate with AR model + # Set torch random seed from generator for reproducibility + # (transformers generate() doesn't accept generator parameter) + if generator is not None: + seed = generator.initial_seed() + torch.manual_seed(seed) + if device is not None and device.type == "cuda": + torch.cuda.manual_seed(seed) + outputs = self.vision_language_encoder.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + ) + + # Extract and upsample prior tokens for each sample + # For left-padded inputs, generated tokens start after the padded input sequence + all_prior_token_ids = [] + max_input_length = inputs["input_ids"].shape[-1] + for idx in range(batch_size): + # For left-padded sequences, generated tokens start at max_input_length + # (padding is on the left, so all sequences end at the same position) + prior_token_ids_d32 = self._extract_large_image_tokens( + outputs[idx : idx + 1], max_input_length, large_image_offset, token_h * token_w + ) + prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) + all_prior_token_ids.append(prior_token_ids) + prior_token_ids = torch.cat(all_prior_token_ids, dim=0) + + # Split prior_token_image_ids and source_image_grid_thw into per-sample lists for easier consumption + prior_token_image_ids_per_sample = None + source_image_grid_thw_per_sample = None + if prior_token_image_ids is not None and source_image_grid_thw is not None: + # Split grids: each sample has num_condition_images grids + source_image_grid_thw_per_sample = list(torch.split(source_image_grid_thw, num_condition_images)) + # Split prior_token_image_ids: tokens per sample may vary due to different image sizes + tokens_per_image = source_image_grid_thw.prod(dim=-1).tolist() + tokens_per_sample = [] + for i in range(batch_size): + start_idx = i * num_condition_images + end_idx = start_idx + num_condition_images + tokens_per_sample.append(sum(tokens_per_image[start_idx:end_idx])) + prior_token_image_ids_per_sample = list(torch.split(prior_token_image_ids, tokens_per_sample)) + + return prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample + + def get_glyph_texts(self, prompt): + """Extract glyph texts from prompt(s). 
Returns a list of lists for batch processing.""" + if isinstance(prompt, str): + prompt = [prompt] + all_ocr_texts = [] + for p in prompt: + ocr_texts = ( + re.findall(r"'([^']*)'", p) + + re.findall(r"\u201c([^\u201c\u201d]*)\u201d", p) + + re.findall(r'"([^"]*)"', p) + + re.findall(r"「([^「」]*)」", p) + ) + all_ocr_texts.append(ocr_texts) + return all_ocr_texts + + def _get_glyph_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """Get glyph embeddings for each prompt in the batch.""" + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + # get_glyph_texts now returns a list of lists (one per prompt) + all_glyph_texts = self.get_glyph_texts(prompt) + + all_glyph_embeds = [] + for glyph_texts in all_glyph_texts: + if len(glyph_texts) == 0: + glyph_texts = [""] + input_ids = self.tokenizer( + glyph_texts, + max_length=max_sequence_length, + truncation=True, + ).input_ids + input_ids = [ + [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids + ] + max_length = max(len(input_ids_) for input_ids_ in input_ids) + attention_mask = torch.tensor( + [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], + device=device, + ) + input_ids = torch.tensor( + [ + input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) + for input_ids_ in input_ids + ], + device=device, + ) + outputs = self.text_encoder(input_ids, attention_mask=attention_mask) + glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) + all_glyph_embeds.append(glyph_embeds) + + # Pad to same sequence length and stack (use left padding to match transformers) + max_seq_len = max(emb.size(1) for emb in all_glyph_embeds) + padded_embeds = [] + for emb in all_glyph_embeds: + if emb.size(1) < max_seq_len: + pad = torch.zeros(emb.size(0), max_seq_len - emb.size(1), emb.size(2), device=device, dtype=emb.dtype) + emb = torch.cat([pad, emb], dim=1) # left padding + padded_embeds.append(emb) + + glyph_embeds = torch.cat(padded_embeds, dim=0) + return glyph_embeds.to(device=device, dtype=dtype) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 2048, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `2048`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype) + + # Repeat embeddings for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + # For GLM-Image, negative_prompt must be "" instead of None + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) + + if num_images_per_prompt > 1: + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + prior_token_ids=None, + prior_token_image_ids=None, + source_image_grid_thw=None, + image=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0 + or width is not None + and width % (self.transformer.config.patch_size * 2) != 0 + ): + # GLM-Image uses 32× downsampling, so the image dimensions must be multiples of 32. + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 4} but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if prompt is not None and prior_token_ids is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prior_token_ids is None: + raise ValueError( + "Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined." + ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + # Validate prior token inputs: for i2i mode, all three must be provided together + # For t2i mode, only prior_token_ids is needed (prior_token_image_ids and source_image_grid_thw should be None) + prior_image_inputs = [prior_token_image_ids, source_image_grid_thw] + num_prior_image_inputs = sum(x is not None for x in prior_image_inputs) + if num_prior_image_inputs > 0 and num_prior_image_inputs < len(prior_image_inputs): + raise ValueError( + "`prior_token_image_ids` and `source_image_grid_thw` must be provided together for i2i mode. " + f"Got prior_token_image_ids={prior_token_image_ids is not None}, " + f"source_image_grid_thw={source_image_grid_thw is not None}." + ) + if num_prior_image_inputs > 0 and prior_token_ids is None: + raise ValueError( + "`prior_token_ids` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided." + ) + if num_prior_image_inputs > 0 and image is None: + raise ValueError( + "`image` must be provided when `prior_token_image_ids` and `source_image_grid_thw` are provided " + "for i2i mode, as the images are needed for VAE encoding to build the KV cache." + ) + + if prior_token_ids is not None and prompt_embeds is None: + raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[ + Union[ + torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] + ] + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.5, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prior_token_ids: Optional[torch.FloatTensor] = None, + prior_token_image_ids: Optional[List[torch.Tensor]] = None, + source_image_grid_thw: Optional[List[torch.Tensor]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: 
List[str] = ["latents"], + max_sequence_length: int = 2048, + ) -> Union[GlmImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. Must contain shape info in the format 'H + W' where H and W are token dimensions (d32). Example: "A beautiful sunset36 24" + generates a 1152x768 image. + image: Optional condition images for image-to-image generation. + height (`int`, *optional*): + The height in pixels. If not provided, derived from prompt shape info. + width (`int`, *optional*): + The width in pixels. If not provided, derived from prompt shape info. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps for DiT. + guidance_scale (`float`, *optional*, defaults to `1.5`): + Guidance scale for classifier-free guidance. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + Random generator for reproducibility. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format: "pil", "np", or "latent". + + Examples: + + Returns: + [`GlmImagePipelineOutput`] or `tuple`: Generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + prior_token_ids, + prior_token_image_ids, + source_image_grid_thw, + image, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 2. Validate and normalize image format + normalized_image = self._validate_and_normalize_images(image, batch_size) + + # 3. Generate prior tokens (batch mode) + # Get a single generator for AR model (use first if list provided) + ar_generator = generator[0] if isinstance(generator, list) else generator + if prior_token_ids is None: + prior_token_ids, prior_token_image_ids_per_sample, source_image_grid_thw_per_sample = ( + self.generate_prior_tokens( + prompt=prompt, + image=normalized_image, + height=height, + width=width, + device=device, + generator=ar_generator, + ) + ) + else: + # User provided prior_token_ids directly (from generate_prior_tokens) + prior_token_image_ids_per_sample = prior_token_image_ids + source_image_grid_thw_per_sample = source_image_grid_thw + + # 4. 
Preprocess images for VAE encoding + preprocessed_images = None + if normalized_image is not None: + preprocessed_images = [] + for prompt_images in normalized_image: + prompt_preprocessed = [] + for img in prompt_images: + image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] + multiple_of = self.vae_scale_factor * self.transformer.config.patch_size + image_height = (image_height // multiple_of) * multiple_of + image_width = (image_width // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width) + prompt_preprocessed.append(img) + height = height or image_height + width = width or image_width + preprocessed_images.append(prompt_preprocessed) + + # 5. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.dtype, + ) + + # 6. Prepare latents and (optional) image kv cache + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers) + + if normalized_image is not None: + kv_caches.set_mode("write") + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1) + + latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype) + latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype) + + # Process each sample's condition images + for prompt_idx in range(batch_size): + prompt_images = preprocessed_images[prompt_idx] + prompt_prior_ids = prior_token_image_ids_per_sample[prompt_idx] + prompt_grid_thw = source_image_grid_thw_per_sample[prompt_idx] + + # Split this sample's prior_token_image_ids by each image's token count + split_sizes = prompt_grid_thw.prod(dim=-1).tolist() + prior_ids_per_image = torch.split(prompt_prior_ids, split_sizes) + # Process each condition image for this sample + for condition_image, condition_image_prior_token_id in zip(prompt_images, prior_ids_per_image): + condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype) + condition_latent = retrieve_latents( + self.vae.encode(condition_image), generator=generator, sample_mode="argmax" + ) + condition_latent = (condition_latent - latents_mean) / latents_std + + _ = self.transformer( + hidden_states=condition_latent, + encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...], + prior_token_id=condition_image_prior_token_id, + prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool), + timestep=torch.zeros((1,), device=device), + target_size=torch.tensor([condition_image.shape[-2:]], device=device), + crop_coords=torch.zeros((1, 2), device=device), + attention_kwargs=attention_kwargs, + kv_caches=kv_caches, + ) + # Move to next sample's cache slot + kv_caches.next_sample() + + # 7. 
Prepare additional timestep conditions + target_size = (height, width) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1] + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + + # 8. Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # Repeat prior_token_ids for num_images_per_prompt + if num_images_per_prompt > 1: + prior_token_ids = prior_token_ids.repeat_interleave(num_images_per_prompt, dim=0) + prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool) + prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + + timestep = t.expand(latents.shape[0]) - 1 + + if prior_token_image_ids_per_sample is not None: + kv_caches.set_mode("read") + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + prior_token_id=prior_token_ids, + prior_token_drop=prior_token_drop_cond, + timestep=timestep, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + kv_caches=kv_caches, + )[0].float() + + # perform guidance + if self.do_classifier_free_guidance: + if prior_token_image_ids_per_sample is not None: + kv_caches.set_mode("skip") + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + prior_token_id=prior_token_ids, + prior_token_drop=prior_token_drop_uncond, + timestep=timestep, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + kv_caches=kv_caches, + )[0].float() + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = 
callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + kv_caches.clear() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return GlmImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/glm_image/pipeline_output.py b/src/diffusers/pipelines/glm_image/pipeline_output.py new file mode 100644 index 000000000000..aec5a5454ea8 --- /dev/null +++ b/src/diffusers/pipelines/glm_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class GlmImagePipelineOutput(BaseOutput): + """ + Output class for CogView3 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index b6af23bca8fd..b41d9772a7cc 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -53,7 +53,6 @@ >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> from diffusers import HiDreamImagePipeline - >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct") >>> text_encoder_4 = LlamaForCausalLM.from_pretrained( ... "meta-llama/Meta-Llama-3.1-8B-Instruct", @@ -965,14 +964,18 @@ def __call__( # 5. 
Prepare timesteps mu = calculate_shift(self.transformer.max_seq) scheduler_kwargs = {"mu": mu} + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device if isinstance(self.scheduler, UniPCMultistepScheduler): - self.scheduler.set_timesteps(num_inference_steps, device=device) # , shift=math.exp(mu)) + self.scheduler.set_timesteps(num_inference_steps, device=timestep_device) # , shift=math.exp(mu)) timesteps = self.scheduler.timesteps else: timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, **scheduler_kwargs, ) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py index b50a6ae3ed27..6bb7a4344da5 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py @@ -728,7 +728,13 @@ def __call__( # 4. Prepare timesteps sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) # 5. Prepare latent variables vae_dtype = self.vae.dtype diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 5c8e295eaf4c..42ab090f1cba 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -683,7 +683,13 @@ def __call__( # 4. Prepare timesteps sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py index 9e9f20c79eba..8c555eabba11 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -852,6 +852,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + if self.transformer.config.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + else: + timestep_r = None + # Step 1: Collect model inputs needed for the guidance method # conditional inputs should always be first element in the tuple guider_inputs = { @@ -893,6 +902,7 @@ def __call__( hidden_states=latent_model_input, image_embeds=image_embeds, timestep=timestep, + timestep_r=timestep_r, attention_kwargs=self.attention_kwargs, return_dict=False, **cond_kwargs, diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index e2f935aaf4b9..052c7b473915 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -169,7 +169,7 @@ class HunyuanDiTPipeline(DiffusionPipeline): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. @@ -204,7 +204,7 @@ def __init__( feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, - tokenizer_2: Optional[MT5Tokenizer] = None, + tokenizer_2: Optional[T5Tokenizer] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py index 7c8468bcb109..3c7442afcaae 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -877,8 +877,12 @@ def __call__( ) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. 
Prepare latent variables diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 10a7962c258c..8c3adf33b845 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -1034,8 +1034,12 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 59f733a498ed..c28e358c51b6 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -881,10 +881,14 @@ def __call__( image = self.image_processor.preprocess(image) # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, timesteps, original_inference_steps=original_inference_steps, strength=strength, diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index e463884618f5..bc71d7bd171a 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -815,8 +815,16 @@ def __call__( ) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + original_inference_steps=original_inference_steps, ) # 5. Prepare latent variable diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 4d42a7049ec9..7fde18e4fbbb 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -767,7 +767,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) self._num_timesteps = len(timesteps) # 5. Prepare latents. 
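Note: the hidream, hunyuan-video, kolors, latent-consistency, and latte hunks above all apply the same refactor: when torch_xla is importable, the timestep schedule is built on CPU instead of the accelerator device before it is handed to `retrieve_timesteps`/`set_timesteps`. The sketch below illustrates that pattern only; it is not the upstream pipeline code, and the helper name `pick_timestep_device` is illustrative. It assumes only diffusers' public `is_torch_xla_available` utility and a default-configured `FlowMatchEulerDiscreteScheduler`.

```python
# Minimal sketch of the XLA timestep-device pattern introduced in the hunks above.
# Under torch_xla, keeping the schedule on CPU avoids moving small timestep tensors
# onto the XLA device every time the scheduler is (re)configured.
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available


def pick_timestep_device(device: torch.device) -> torch.device:
    # Mirrors the `timestep_device = "cpu" if XLA_AVAILABLE else device` branches
    # added throughout these pipelines.
    return torch.device("cpu") if is_torch_xla_available() else device


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scheduler = FlowMatchEulerDiscreteScheduler()
    # Only the timestep schedule lives on the chosen device; latents and the model
    # stay on `device` as usual.
    scheduler.set_timesteps(num_inference_steps=30, device=pick_timestep_device(device))
    print(scheduler.timesteps.device, scheduler.timesteps.shape)
```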
diff --git a/src/diffusers/pipelines/longcat_image/__init__.py b/src/diffusers/pipelines/longcat_image/__init__.py new file mode 100644 index 000000000000..e4bb0e5819c8 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/__init__.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_longcat_image"] = ["LongCatImagePipeline"] + _import_structure["pipeline_longcat_image_edit"] = ["LongCatImageEditPipeline"] + _import_structure["pipeline_output"] = ["LongCatImagePipelineOutput"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_longcat_image import LongCatImagePipeline + from .pipeline_longcat_image_edit import LongCatImageEditPipeline + from .pipeline_output import LongCatImagePipelineOutput + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py new file mode 100644 index 000000000000..ca28422f9ca0 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py @@ -0,0 +1,666 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +import re +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LongCatImagePipelineOutput +from .system_messages import SYSTEM_PROMPT_EN, SYSTEM_PROMPT_ZH + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LongCatImagePipeline + + >>> pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。" + >>> image = pipe( + ... prompt, + ... height=768, + ... width=1344, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... enable_cfg_renorm=True, + ... ).images[0] + >>> image.save("longcat_image.png") + ``` +""" + + +def get_prompt_language(prompt): + pattern = re.compile(r"[\u4e00-\u9fff]") + if bool(pattern.search(prompt)): + return "zh" + return "en" + + +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" 
+ re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + print('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + The pipeline for text-to-image generation. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def rewire_prompt(self, prompt, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + all_text = [] + for each_prompt in prompt: + language = get_prompt_language(each_prompt) + if language == "zh": + question = SYSTEM_PROMPT_ZH + f"\n用户输入为:{each_prompt}\n改写后的prompt为:" + else: + question = SYSTEM_PROMPT_EN + f"\nUser Input: {each_prompt}\nRewritten prompt:" + message = [ + { + "role": "user", + "content": [ + {"type": "text", "text": question}, + ], + } + ] + # Preparation for inference + text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + all_text.append(text) + + inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device) + + generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length) + generated_ids.to(device) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = self.text_processor.batch_decode( + 
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + rewrite_prompt = output_text + return rewrite_prompt + + def _encode_prompt(self, prompt: List[str]): + batch_all_tokens = [] + + for each_prompt in prompt: + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(each_prompt): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(all_tokens)}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + batch_all_tokens.append(all_tokens) + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": batch_all_tokens}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + prefix_tokens = self.tokenizer(self.prompt_template_encode_prefix, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + prefix_len = len(prefix_tokens) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + batch_size = text_tokens_and_mask.input_ids.size(0) + + prefix_tokens_batch = prefix_tokens.unsqueeze(0).expand(batch_size, -1) + suffix_tokens_batch = suffix_tokens.unsqueeze(0).expand(batch_size, -1) + prefix_mask_batch = prefix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + suffix_mask_batch = suffix_tokens_mask.unsqueeze(0).expand(batch_size, -1) + + input_ids = torch.cat((prefix_tokens_batch, text_tokens_and_mask.input_ids, suffix_tokens_batch), dim=-1) + attention_mask = torch.cat((prefix_mask_batch, text_tokens_and_mask.attention_mask, suffix_mask_batch), dim=-1) + + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: Optional[int] = 1, + prompt_embeds: Optional[torch.Tensor] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * 
num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds.to(self.device), text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(self.tokenizer_max_length, self.tokenizer_max_length), + height=height // 2, + width=width // 2, + ).to(device) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device) + latents = latents.to(dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + enable_cfg_renorm: Optional[bool] = True, + cfg_renorm_min: Optional[float] = 0.0, + enable_prompt_rewrite: Optional[bool] = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + enable_cfg_renorm: Whether to enable cfg_renorm. Enabling cfg_renorm will improve image quality, + but it may lead to a decrease in the stability of some image outputs.. + cfg_renorm_min: The minimum value of the cfg_renorm_scale range (0-1). + cfg_renorm_min = 1.0, renorm has no effect, while cfg_renorm_min=0.0, the renorm range is larger. + enable_prompt_rewrite: whether to enable prompt rewrite. + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if enable_prompt_rewrite: + prompt = self.rewire_prompt(prompt, device) + logger.info(f"Rewrite prompt {prompt}!") + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if enable_cfg_renorm: + cond_norm = torch.norm(noise_pred_text, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + scale = (cond_norm / (noise_norm + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + noise_pred = noise_pred * scale + else: + noise_pred = noise_pred_text + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py new file mode 100644 index 000000000000..e55a2a47f343 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py @@ -0,0 +1,727 @@ +# Copyright 2025 MeiTuan LongCat-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import math +import re +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import LongCatImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import LongCatImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import LongCatImageEditPipeline + + >>> pipe = LongCatImageEditPipeline.from_pretrained( + ... "meituan-longcat/LongCat-Image-Edit", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "change the cat to dog." + >>> input_image = Image.open("test.jpg").convert("RGB") + >>> image = pipe( + ... input_image, + ... prompt, + ... num_inference_steps=50, + ... guidance_scale=4.5, + ... generator=torch.Generator("cpu").manual_seed(43), + ... 
).images[0] + >>> image.save("longcat_image_edit.png") + ``` +""" + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.split_quotation +def split_quotation(prompt, quote_pairs=None): + """ + Implement a regex-based string splitting algorithm that identifies delimiters defined by single or double quote + pairs. Examples:: + >>> prompt_en = "Please write 'Hello' on the blackboard for me." >>> print(split_quotation(prompt_en)) >>> # + output: [('Please write ', False), ("'Hello'", True), (' on the blackboard for me.', False)] + """ + word_internal_quote_pattern = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + matches_word_internal_quote_pattern = word_internal_quote_pattern.findall(prompt) + mapping_word_internal_quote = [] + + for i, word_src in enumerate(set(matches_word_internal_quote_pattern)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping_word_internal_quote.append([word_src, word_tgt]) + + if quote_pairs is None: + quote_pairs = [("'", "'"), ('"', '"'), ("‘", "’"), ("“", "”")] + pattern = "|".join([re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) for q1, q2 in quote_pairs]) + parts = re.split(f"({pattern})", prompt) + + result = [] + for part in parts: + for word_src, word_tgt in mapping_word_internal_quote: + part = part.replace(word_tgt, word_src) + if re.match(pattern, part): + if len(part): + result.append((part, True)) + else: + if len(part): + result.append((part, False)) + return result + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.prepare_pos_ids +def prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=None, height=None, width=None): + if type == "text": + assert num_token + if height or width: + print('Warning: The parameters of height and width will be ignored in "text" type.') + pos_ids = torch.zeros(num_token, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = torch.arange(num_token) + start[0] + pos_ids[..., 2] = torch.arange(num_token) + start[1] + elif type == "image": + assert height and width + if num_token: + print('Warning: The parameter of num_token will be ignored in "image" type.') + pos_ids = torch.zeros(height, width, 3) + pos_ids[..., 0] = modality_id + pos_ids[..., 1] = pos_ids[..., 1] + torch.arange(height)[:, None] + start[0] + pos_ids[..., 2] = pos_ids[..., 2] + torch.arange(width)[None, :] + start[1] + pos_ids = pos_ids.reshape(height * width, 3) + else: + raise KeyError(f'Unknow type {type}, only support "text" or "image".') + return pos_ids + + +# Copied from diffusers.pipelines.longcat_image.pipeline_longcat_image.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = width if width % 16 == 0 else (width // 16 + 1) * 16 + height = height if height % 16 == 0 else (height // 16 + 1) * 16 + + width = int(width) + height = int(height) + + return width, height + + +class LongCatImageEditPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + The LongCat-Image-Edit pipeline for image editing. 
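+
+    The source image conditions generation in two ways: it is passed to the Qwen2.5-VL text
+    encoder together with the editing instruction to build the prompt embeddings, and it is
+    encoded by the VAE into reference latents that are concatenated with the noise latents
+    during flow-matching denoising; only the noise latents are updated at each step.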
+ """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + text_processor: Qwen2VLProcessor, + transformer: LongCatImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_processor=text_processor, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor_vl = text_processor.image_processor + + self.image_token = "<|image_pad|>" + self.prompt_template_encode_prefix = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + self.prompt_template_encode_suffix = "<|im_end|>\n<|im_start|>assistant\n" + self.default_sample_size = 128 + self.tokenizer_max_length = 512 + + def _encode_prompt(self, prompt, image): + raw_vl_input = self.image_processor_vl(images=image, return_tensors="pt") + pixel_values = raw_vl_input["pixel_values"] + image_grid_thw = raw_vl_input["image_grid_thw"] + all_tokens = [] + for clean_prompt_sub, matched in split_quotation(prompt[0]): + if matched: + for sub_word in clean_prompt_sub: + tokens = self.tokenizer(sub_word, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + else: + tokens = self.tokenizer(clean_prompt_sub, add_special_tokens=False)["input_ids"] + all_tokens.extend(tokens) + + if len(all_tokens) > self.tokenizer_max_length: + logger.warning( + "Your input was truncated because `max_sequence_length` is set to " + f" {self.tokenizer_max_length} input token nums : {len(len(all_tokens))}" + ) + all_tokens = all_tokens[: self.tokenizer_max_length] + + text_tokens_and_mask = self.tokenizer.pad( + {"input_ids": [all_tokens]}, + max_length=self.tokenizer_max_length, + padding="max_length", + return_attention_mask=True, + return_tensors="pt", + ) + + text = self.prompt_template_encode_prefix + + merge_length = self.image_processor_vl.merge_size**2 + while self.image_token in text: + num_image_tokens = image_grid_thw.prod() // merge_length + text = text.replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + text = text.replace("<|placeholder|>", self.image_token) + + prefix_tokens = self.tokenizer(text, add_special_tokens=False)["input_ids"] + suffix_tokens = self.tokenizer(self.prompt_template_encode_suffix, add_special_tokens=False)["input_ids"] + + vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>") + prefix_len = prefix_tokens.index(vision_start_token_id) + suffix_len = len(suffix_tokens) + + prefix_tokens_mask = torch.tensor([1] * len(prefix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + suffix_tokens_mask = torch.tensor([1] * len(suffix_tokens), dtype=text_tokens_and_mask.attention_mask[0].dtype) + + prefix_tokens = torch.tensor(prefix_tokens, 
dtype=text_tokens_and_mask.input_ids.dtype) + suffix_tokens = torch.tensor(suffix_tokens, dtype=text_tokens_and_mask.input_ids.dtype) + + input_ids = torch.cat((prefix_tokens, text_tokens_and_mask.input_ids[0], suffix_tokens), dim=-1) + attention_mask = torch.cat( + (prefix_tokens_mask, text_tokens_and_mask.attention_mask[0], suffix_tokens_mask), dim=-1 + ) + + input_ids = input_ids.unsqueeze(0).to(self.device) + attention_mask = attention_mask.unsqueeze(0).to(self.device) + + pixel_values = pixel_values.to(self.device) + image_grid_thw = image_grid_thw.to(self.device) + + text_output = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + ) + # [max_sequence_length, batch, hidden_size] -> [batch, max_sequence_length, hidden_size] + # clone to have a contiguous tensor + prompt_embeds = text_output.hidden_states[-1].detach() + prompt_embeds = prompt_embeds[:, prefix_len:-suffix_len, :] + return prompt_embeds + + def encode_prompt( + self, + prompt: List[str] = None, + image: Optional[torch.Tensor] = None, + num_images_per_prompt: Optional[int] = 1, + prompt_embeds: Optional[torch.Tensor] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is None: + prompt_embeds = self._encode_prompt(prompt, image) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = prepare_pos_ids(modality_id=0, type="text", start=(0, 0), num_token=prompt_embeds.shape[1]).to( + self.device + ) + return prompt_embeds, text_ids + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
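+        # Illustrative example: for a 1024x1024 output with vae_scale_factor=8, the lines
+        # below give a 128x128 latent grid, so the packed sequence has (128//2)*(128//2)=4096
+        # patches of channels=64 values each, and the final reshape recovers a
+        # (batch_size, 16, 128, 128) latent ready for VAE decoding.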
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + return image_latents + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + dtype, + prompt_embeds_length, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + image_latents, image_latents_ids = None, None + + if image is not None: + image = image.to(device=self.device, dtype=dtype) + + if image.shape[1] != self.vae.config.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) + + image_latents_ids = prepare_pos_ids( + modality_id=2, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device, dtype=torch.float64) + + shape = (batch_size, num_channels_latents, height, width) + latents_ids = prepare_pos_ids( + modality_id=1, + type="image", + start=(prompt_embeds_length, prompt_embeds_length), + height=height // 2, + width=width // 2, + ).to(device) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents, latents_ids, image_latents_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + def check_inputs( + self, prompt, height, width, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None: + if isinstance(prompt, str): + pass + elif isinstance(prompt, list) and len(prompt) == 1: + pass + else: + raise ValueError( + f"`prompt` must be a `str` or a `list` of length 1, but is {prompt} (type: {type(prompt)})" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + @torch.no_grad() + def __call__( + self, + image: Optional[PIL.Image.Image] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Examples: + + Returns: + [`~pipelines.LongCatImagePipelineOutput`] or `tuple`: [`~pipelines.LongCatImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] * 1.0 / image_size[1]) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + calculated_height, + calculated_width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = self.image_processor.resize(image, calculated_height // 2, calculated_width // 2) + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + + negative_prompt = "" if negative_prompt is None else negative_prompt + (prompt_embeds, text_ids) = self.encode_prompt( + prompt=prompt, image=prompt_image, prompt_embeds=prompt_embeds, num_images_per_prompt=num_images_per_prompt + ) + if self.do_classifier_free_guidance: + (negative_prompt_embeds, negative_text_ids) = self.encode_prompt( + prompt=negative_prompt, + image=prompt_image, + prompt_embeds=negative_prompt_embeds, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = 16 + latents, image_latents, latents_ids, image_latents_ids = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + calculated_height, + calculated_width, + prompt_embeds.dtype, + prompt_embeds.shape[1], + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = None + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + if image is not None: + latent_image_ids = torch.cat([latents_ids, image_latents_ids], dim=0) + else: + latent_image_ids = latents_ids + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred_text = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_text = noise_pred_text[:, :image_seq_len] + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :image_seq_len] + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred = noise_pred_text + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, calculated_height, calculated_width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + if latents.dtype != self.vae.dtype: + latents = latents.to(dtype=self.vae.dtype) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return LongCatImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/longcat_image/pipeline_output.py b/src/diffusers/pipelines/longcat_image/pipeline_output.py new file mode 100644 index 000000000000..e3c25f1cbfa7 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class LongCatImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
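+            When the pipeline is called with `output_type="latent"`, this field instead holds the
+            packed latent tensor prior to VAE decoding.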
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/longcat_image/system_messages.py b/src/diffusers/pipelines/longcat_image/system_messages.py new file mode 100644 index 000000000000..b8b2318e4e81 --- /dev/null +++ b/src/diffusers/pipelines/longcat_image/system_messages.py @@ -0,0 +1,142 @@ +SYSTEM_PROMPT_EN = """ +You are a prompt engineering expert for text-to-image models. Since text-to-image models have limited capabilities in +understanding user prompts, you need to identify the core theme and intent of the user's input and improve the model's +understanding accuracy and generation quality through optimization and rewriting. The rewrite must strictly retain all +information from the user's original prompt without deleting or distorting any details. Specific requirements are as +follows: +1. The rewrite must not affect any information expressed in the user's original prompt; the rewritten prompt should use + coherent natural language, avoid low-information redundant descriptions, and keep the rewritten prompt length as + concise as possible. +2. Ensure consistency between input and output languages: Chinese input yields Chinese output, and English input yields + English output. The rewritten token count should not exceed 512. +3. The rewritten description should further refine subject characteristics and aesthetic techniques appearing in the + original prompt, such as lighting and textures. +4. If the original prompt does not specify an image style, ensure the rewritten prompt uses a **realistic photography + style**. If the user specifies a style, retain the user's style. +5. When the original prompt requires reasoning to clarify user intent, use logical reasoning based on world knowledge + to convert vague abstract descriptions into specific tangible objects (e.g., convert "the tallest animal" to "a + giraffe"). +6. When the original prompt requires text generation, please use double quotes to enclose the text part (e.g., `"50% + OFF"`). +7. When the original prompt requires generating text-heavy scenes like webpages, logos, UIs, or posters, and no + specific text content is specified, you need to infer appropriate text content and enclose it in double quotes. For + example, if the user inputs: "A tourism flyer with a grassland theme," it should be rewritten as: "A tourism flyer + with the image title 'Grassland'." +8. When negative words exist in the original prompt, ensure the rewritten prompt does not contain negative words. For + example, "a lakeside without boats" should be rewritten such that the word "boat" does not appear at all. +9. Except for text content explicitly requested by the user, **adding any extra text content is prohibited**. +Here are examples of rewrites for different types of prompts: # Examples (Few-Shot Learning) + 1. User Input: An animal with nine lives. + Rewrite Output: A cat bathed in soft sunlight, its fur soft and glossy. The background is a comfortable home + environment with light from the window filtering through curtains, creating a warm light and shadow effect. The + shot uses a medium distance perspective to highlight the cat's leisurely and stretched posture. Light cleverly hits + the cat's face, emphasizing its spirited eyes and delicate whiskers, adding depth and affinity to the image. + 2. User Input: Create an anime-style tourism flyer with a grassland theme. + Rewrite Output: In the lower right of the center, a short-haired girl sits sideways on a gray, irregularly shaped + rock. 
She wears a white short-sleeved dress and brown flat shoes, holding a bunch of small white flowers in her + left hand, smiling with her legs hanging naturally. The girl has dark brown shoulder-length hair with bangs + covering her forehead, brown eyes, and a slightly open mouth. The rock surface has textures of varying depths. To + the girl's left and front is lush grass, with long, yellow-green blades, some glowing golden in the sunlight. The + grass extends into the distance, forming rolling green hills that fade in color as they recede. The sky occupies + the upper half of the picture, pale blue dotted with a few fluffy white clouds. In the upper left corner, there is + a line of text in italic, dark green font reading "Explore Nature's Peace". Colors are dominated by green, blue, + and yellow, fluid lines, and distinct light and shadow contrast, creating a quiet and comfortable atmosphere. + 3. User Input: A Christmas sale poster with a red background, promoting a Buy 1 Get 1 Free milk tea offer. + Rewrite Output: The poster features an overall red tone, embellished with white snowflake patterns on the top and + left side. The upper right features a bunch of holly leaves with red berries and a pine cone. In the upper center, + golden 3D text reads "Christmas Heartwarming Feedback" centered, along with red bold text "Buy 1 Get 1". Below, two + transparent cups filled with bubble tea are placed side by side; the tea is light brown with dark brown pearls + scattered at the bottom and middle. Below the cups, white snow piles up, decorated with pine branches, red berries, + and pine cones. A blurry Christmas tree is faintly visible in the lower right corner. The image has high clarity, + accurate text content, a unified design style, a prominent Christmas theme, and a reasonable layout, providing + strong visual appeal. + 4. User Input: A woman indoors shot in natural light, smiling with arms crossed, showing a relaxed and confident + posture. + Rewrite Output: The image features a young Asian woman with long dark brown hair naturally falling over her + shoulders, with some strands illuminated by light, showing a soft sheen. Her features are delicate, with long + eyebrows, bright and spirited dark brown eyes looking directly at the camera, revealing peace and confidence. She + has a high nose bridge, full lips with nude lipstick, and corners of the mouth slightly raised in a faint smile. + Her skin is fair, with cheeks and collarbones illuminated by warm light, showing a healthy ruddiness. She wears a + black spaghetti strap tank top revealing graceful collarbone lines, and a thin gold necklace with small beads and + metal bars glinting in the light. Her outer layer is a beige knitted cardigan, soft in texture with visible + knitting patterns on the sleeves. Her arms are crossed over her chest, hands covered by the cardigan sleeves, in a + relaxed posture. The background is a pure dark brown without extra decoration, making the figure the absolute + focus. The figure is located in the center of the frame. Light enters from the upper right, creating bright spots + on her left cheek, neck, and collarbone, while the right side is slightly shadowed, creating a three-dimensional + and soft tone. Image details are clear, showcasing skin texture, hair, and clothing materials well. Colors are + dominated by warm tones, with the combination of beige and dark brown creating a warm and comfortable atmosphere. + The overall style is natural, elegant, and artistic. + 5. 
User Input: Create a series of images showing the growth process of an apple from seed to fruit. The series should + include four stages: 1. Sowing, 2. Seedling growth, 3. Plant maturity, 4. Fruit harvesting. + Rewrite Output: A 4-panel exquisite illustration depicting the growth process of an apple, capturing each stage + precisely and clearly. 1. "Sowing": A close-up shot of a hand gently placing a small apple seed into fertile dark + soil, with visible soil texture and the seed's smooth surface. The background is a soft-focus garden dotted with + green leaves and sunlight filtering through. 2. "Seedling Growth": A young apple sapling breaks through the soil, + stretching tender green leaves toward the sky. The scene is set in a vibrant garden illuminated by warm golden + light, highlighting the seedling's delicate structure. 3. "Plant Maturity": A mature apple tree, lush with branches + and leaves, covered in tender green foliage and developing small apples. The background is a vibrant orchard under + a clear blue sky, with dappled sunlight creating a peaceful atmosphere. 4. "Fruit Harvesting": A hand reaches into + the tree to pick a ripe red apple, its smooth skin glistening in the sun. The scene shows the abundance of the + orchard, with baskets of apples in the background, giving a sense of fulfillment. Each illustration uses a + realistic style, focusing on details and harmonious colors to showcase the natural beauty and development of the + apple's life cycle. + 6. User Input: If 1 represents red, 2 represents green, 3 represents purple, and 4 represents yellow, please generate + a four-color rainbow based on this rule. The color order from top to bottom is 3142. + Rewrite Output: The image consists of four horizontally arranged colored stripes, ordered from top to bottom as + purple, red, yellow, and green. A white number is centered on each stripe. The top purple stripe features the + number "3", the red stripe below it has the number "1", the yellow stripe further down has the number "4", and the + bottom green stripe has the number "2". All numbers use a sans-serif font in pure white, forming a sharp contrast + with the background colors to ensure good readability. The stripes have high color saturation and a slight texture. + The overall layout is simple and clear, with distinct visual effects and no extra decorative elements, emphasizing + the numerical information. The image is high definition, with accurate colors and a consistent style, offering + strong visual appeal. + 7. User Input: A stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", natural light, background is a + Chinese garden. + Rewrite Output: An ancient stone tablet carved with "Guan Guan Ju Jiu, On the River Isle", the surface covered with + traces of time, the writing clear and deep. Natural light falls from above, softly illuminating every detail of the + stone tablet and enhancing its sense of history. The background is an elegant Chinese garden featuring lush bamboo + forests, winding paths, and quiet pools, creating a serene and distant atmosphere. The overall picture uses a + realistic style with rich details and natural light and shadow effects, highlighting the cultural heritage of the + stone tablet and the classical beauty of the garden. +# Output Format Please directly output the rewritten and optimized Prompt content. 
Do not include any explanatory +language or JSON formatting, and do not add opening or closing quotes yourself.""" + + +SYSTEM_PROMPT_ZH = """ +你是一名文生图模型的prompt +engineering专家。由于文生图模型对用户prompt的理解能力有限,你需要识别用户输入的核心主题和意图,并通过优化改写提升模型的理解准确性和生成质量。改写必须严格保留用户原始prompt的所有信息,不得删减或曲解任何细节。 +具体要求如下: +1. 改写不能影响用户原始prompt里表达的任何信息,改写后的prompt应该使用连贯的自然语言表达,不要出现低信息量的冗余描述,尽可能保持改写后prompt长度精简。 +2. 请确保输入和输出的语言类型一致,中文输入中文输出,英文输入英文输出,改写后的token数量不要超过512个; +3. 改写后的描述应当进一步完善原始prompt中出现的主体特征、美学技巧,如打光、纹理等; +4. 如果原始prompt没有指定图片风格时,确保改写后的prompt使用真实摄影风格,如果用户指定了图片风格,则保留用户风格; +5. 当原始prompt需要推理才能明确用户意图时,根据世界知识进行适当逻辑推理,将模糊抽象描述转化为具体指向事物(例:将"最高的动物"转化为"一头长颈鹿")。 +6. 当原始prompt需要生成文字时,请使用双引号圈定文字部分,例:`"限时5折"`)。 +7. 当原始prompt需要生成网页、logo、ui、海报等文字场景时,且没有指定具体的文字内容时,需要推断出合适的文字内容,并使用双引号圈定,如用户输入:一个旅游宣传单,以草原为主题。应该改写成:一个旅游宣传单,图片标题为“草原”。 +8. 当原始prompt中存在否定词时,需要确保改写后的prompt不存在否定词,如没有船的湖边,改写后的prompt不能出现船这个词汇。 +9. 除非用户指定生成品牌logo,否则不要增加额外的品牌logo. +10. 除了用户明确要求书写的文字内容外,**禁止增加任何额外的文字内容**。 +以下是针对不同类型prompt改写的示例: + +# Examples (Few-Shot Learning) + 1. 用户输入: 九条命的动物。 + 改写输出: + 一只猫,被柔和的阳光笼罩着,毛发柔软而富有光泽。背景是一个舒适的家居环境,窗外的光线透过窗帘,形成温馨的光影效果。镜头采用中距离视角,突出猫悠闲舒展的姿态。光线巧妙地打在猫的脸部,强调它灵动的眼睛和精致的胡须,增加画面的层次感与亲和力。 + 2. 用户输入: 制作一个动画风格的旅游宣传单,以草原为主题。 + 改写输出: + 画面中央偏右下角,一个短发女孩侧身坐在灰色的不规则形状岩石上,她穿着白色短袖连衣裙和棕色平底鞋,左手拿着一束白色小花,面带微笑,双腿自然垂下。女孩的头发为深棕色,齐肩短发,刘海覆盖额头,眼睛呈棕色,嘴巴微张。岩石表面有深浅不一的纹理。女孩的左侧和前方是茂盛的草地,草叶细长,呈黄绿色,部分草叶在阳光下泛着金色的光芒,仿佛被阳光照亮。草地向远处延伸,形成连绵起伏的绿色山丘,山丘的颜色由近及远逐渐变浅。天空占据了画面的上半部分,呈淡蓝色,点缀着几朵白色蓬松的云彩。画面的左上角有一行文字,文字内容是斜体、深绿色的“Explore + Nature's Peace”。色彩以绿色、蓝色和黄色为主,线条流畅,光影明暗对比明显,营造出一种宁静、舒适的氛围。 + 3. 用户输入: 一张以红色为背景的圣诞节促销海报,主要宣传奶茶买一送一的优惠活动。 + 改写输出: 海报整体呈现红色调,上方和左侧点缀着白色雪花图案,右上方有一束冬青叶和红色浆果,以及一个松果。海报中央偏上位置,金色立体字样“圣诞节 + 暖心回馈”居中排列,和红色粗体字“买1送1”。海报下方,两个装满珍珠奶茶的透明杯子并排摆放,杯中奶茶呈浅棕色,底部和中间散布着深棕色珍珠。杯子下方,堆积着白色雪花,雪花上装饰着松枝、红色浆果和松果。右下角隐约可见一棵模糊的圣诞树。图片清晰度高,文字内容准确,整体设计风格统一,圣诞主题突出,排版布局合理,具有较强的视觉吸引力。 + 4. 用户输入: 一位女性在室内以自然光线拍摄,她面带微笑,双臂交叉,展现出轻松自信的姿态。 + 改写输出: + 画面中是一位年轻的亚洲女性,她拥有深棕色的长发,发丝自然地垂落在双肩,部分发丝被光线照亮,呈现出柔和的光泽。她的五官精致,眉毛修长,眼睛明亮有神,瞳孔呈深棕色,眼神直视镜头,流露出平和与自信。鼻梁挺拔,嘴唇丰满,涂有裸色系唇膏,嘴角微微上扬,展现出浅浅的微笑。她的肤色白皙,脸颊和锁骨处被暖色调的光线照亮,呈现出健康的红润感。她穿着一件黑色的细吊带背心,肩带纤细,露出优美的锁骨线条。脖颈上佩戴着一条金色的细项链,项链由小珠子和几个细长的金属条组成,在光线下闪烁着光泽。她的外搭是一件米黄色的针织开衫,材质柔软,袖子部分有明显的针织纹理。她双臂交叉在胸前,双手被开衫的袖子覆盖,姿态放松。背景是纯粹的深棕色,没有多余的装饰,使得人物成为画面的绝对焦点。人物位于画面中央。光线从画面的右上方射入,在人物的左侧脸颊、脖颈和锁骨处形成明亮的光斑,右侧则略显阴影,营造出立体感和柔和的影调。图像细节清晰,人物的皮肤纹理、发丝以及衣物材质都得到了很好的展现。色彩以暖色调为主,米黄色和深棕色的搭配营造出温馨舒适的氛围。整体呈现出一种自然、优雅且富有亲和力的艺术风格。 + 5. 用户输入:创作一系列图片,展现苹果从种子到结果的生长过程。该系列图片应包含以下四个阶段:1. 播种,2. 幼苗生长,3. 植物成熟,4. 果实采摘。 + 改写输出:一个4宫格的精美插图,描绘苹果的生长过程,精确清晰地捕捉每个阶段。1.“播种”:特写镜头,一只手轻轻地将一颗小小的苹果种子放入肥沃的深色土壤中,土壤的纹理和种子光滑的表面清晰可见。背景是花园的柔焦画面,点缀着绿色的树叶和透过树叶洒下的阳光。2.“幼苗生长”:一棵幼小的苹果树苗破土而出,嫩绿的叶子向天空舒展。场景设定在一个生机勃勃的花园中,温暖的金光照亮了它。幼苗的纤细结构。3.“植物的成熟”:一棵成熟的苹果树,枝繁叶茂,挂满了嫩绿的叶子和正在萌发的小苹果。背景是一片生机勃勃的果园,湛蓝的天空下,斑驳的阳光营造出宁静祥和的氛围。4.“采摘果实”:一只手伸向树上,摘下一个成熟的红苹果,苹果光滑的果皮在阳光下闪闪发光。画面展现了果园的丰收景象,背景中摆放着一篮篮的苹果,给人一种圆满满足的感觉。每幅插图都采用写实风格,注重细节,色彩和谐,展现了苹果生命周期的自然之美和发展过程。 + 6. 用户输入: 如果1代表红色,2代表绿色,3代表紫色,4代表黄色,请按照此规则生成四色彩虹。它的颜色顺序从上到下是3142 + 改写输出:图片由四个水平排列的彩色条纹组成,从上到下依次为紫色、红色、黄色和绿色。每个条纹上都居中放置一个白色数字。最上方的紫色条纹上是数字“3”,其下方红色条纹上是数字“1”,再下方黄色条纹上是数字“4”,最下方的绿色条纹上是数字“2”。所有数字均采用无衬线字体,颜色为纯白色,与背景色形成鲜明对比,确保了良好的可读性。条纹的颜色饱和度高,且带有轻微的纹理感,整体排版简洁明了,视觉效果清晰,没有多余的装饰元素,强调了数字信息本身。图片整体清晰度高,色彩准确,风格一致,具有较强的视觉吸引力。 + 7. 
用户输入:石碑上刻着“关关雎鸠,在河之洲”,自然光照,背景是中式园林 + 改写输出:一块古老的石碑上刻着“关关雎鸠,在河之洲”,石碑表面布满岁月的痕迹,字迹清晰而深刻。自然光线从上方洒下,柔和地照亮石碑的每一个细节,增强了其历史感。背景是一座典雅的中式园林,园林中有翠绿的竹林、蜿蜒的小径和静谧的水池,营造出一种宁静而悠远的氛围。整体画面采用写实风格,细节丰富,光影效果自然,突出了石碑的文化底蕴和园林的古典美。 +# 输出格式 请直接输出改写优化后的 Prompt 内容,不要包含任何解释性语言或 JSON 格式,不要自行添加开头或结尾的引号。 +""" diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py index 6001867916b3..05117d35d3b4 100644 --- a/src/diffusers/pipelines/ltx/__init__.py +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -25,6 +25,7 @@ _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"] _import_structure["pipeline_ltx"] = ["LTXPipeline"] _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"] + _import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"] @@ -39,6 +40,7 @@ from .modeling_latent_upsampler import LTXLatentUpsamplerModel from .pipeline_ltx import LTXPipeline from .pipeline_ltx_condition import LTXConditionPipeline + from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline from .pipeline_ltx_image2video import LTXImageToVideoPipeline from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 8ca8b4419e18..3c90da1c7051 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -726,10 +726,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, timesteps, sigmas=sigmas, mu=mu, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index 48a6f0837c8d..10c9432a7f46 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1102,11 +1102,24 @@ def __call__( latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio latent_width = width // self.vae_spatial_compression_ratio + if timesteps is None: sigmas = linear_quadratic_schedule(num_inference_steps) timesteps = sigmas * 1000 - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + timestep_device, + timesteps, + ) sigmas = self.scheduler.sigmas + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) latent_sigma = None diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py b/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py new file mode 100644 index 000000000000..7965bd3b4b87 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py @@ -0,0 +1,1408 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler, LTXEulerAncestralRFScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXEulerAncestralRFScheduler, LTXI2VLongMultiPromptPipeline + + >>> pipe = LTXI2VLongMultiPromptPipeline.from_pretrained("LTX-Video-0.9.8-13B-distilled") + >>> # For ComfyUI parity, swap in the RF scheduler (keeps the original config). + >>> pipe.scheduler = LTXEulerAncestralRFScheduler.from_config(pipe.scheduler.config) + >>> pipe = pipe.to("cuda").to(dtype=torch.bfloat16) + >>> # Example A: get decoded frames (PIL) + >>> out = pipe( + ... prompt="a chimpanzee walks | a chimpanzee eats", + ... num_frames=161, + ... height=512, + ... width=704, + ... temporal_tile_size=80, + ... temporal_overlap=24, + ... output_type="pil", + ... return_dict=True, + ... ) + >>> frames = out.frames[0] # list of PIL.Image.Image + >>> # Example B: get latent video and decode later (saves VRAM during sampling) + >>> out_latent = pipe(prompt="a chimpanzee walking", output_type="latent", return_dict=True).frames + >>> frames = pipe.vae_decode_tiled(out_latent, output_type="pil")[0] + ``` +""" + + +def get_latent_coords( + latent_num_frames, latent_height, latent_width, batch_size, device, rope_interpolation_scale, latent_idx +): + """ + Compute latent patch top-left coordinates in (t, y, x) order. + + Args: + latent_num_frames: int. Number of latent frames (T_lat). + latent_height: int. Latent height (H_lat). + latent_width: int. Latent width (W_lat). + batch_size: int. Batch dimension (B). + device: torch.device for the resulting tensor. + rope_interpolation_scale: + tuple[int|float, int|float, int|float]. Scale per (t, y, x) latent step to pixel coords. + latent_idx: Optional[int]. 
When not None, shifts the time coordinate to align segments: + - <= 0 uses step multiples of rope_interpolation_scale[0] + - > 0 starts at 1 then increments by rope_interpolation_scale[0] + + Returns: + Tensor of shape [B, 3, T_lat * H_lat * W_lat] containing top-left coordinates per latent patch, repeated for each + batch element. + """ + latent_sample_coords = torch.meshgrid( + torch.arange(0, latent_num_frames, 1, device=device), + torch.arange(0, latent_height, 1, device=device), + torch.arange(0, latent_width, 1, device=device), + indexing="ij", + ) + latent_sample_coords = torch.stack(latent_sample_coords, dim=0) + latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_coords = latent_coords.flatten(2) + pixel_coords = latent_coords * torch.tensor(rope_interpolation_scale, device=latent_coords.device)[None, :, None] + if latent_idx is not None: + if latent_idx <= 0: + frame_idx = latent_idx * rope_interpolation_scale[0] + else: + frame_idx = 1 + (latent_idx - 1) * rope_interpolation_scale[0] + if frame_idx == 0: + pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - rope_interpolation_scale[0]).clamp(min=0) + pixel_coords[:, 0] += frame_idx + return pixel_coords + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def adain_normalize_latents( + curr_latents: torch.Tensor, ref_latents: Optional[torch.Tensor], factor: float +) -> torch.Tensor: + """ + Optional AdaIN normalization: channel-wise mean/variance matching of curr_latents to ref_latents, controlled by + factor. + + Args: + curr_latents: Tensor [B, C, T, H, W]. Current window latents. + ref_latents: + Optional[Tensor] [B, C, T_ref, H, W]. Reference latents (e.g., first window) used to compute target stats. + factor: float in [0, 1]. 0 keeps current stats; 1 matches reference stats. + + Returns: + Tensor with per-channel mean/std blended towards the reference. 
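+
+    Example (illustrative only, with made-up statistics): if a channel of `curr_latents` has mean 0.2 and std 1.5,
+    the same channel of `ref_latents` has mean 0.0 and std 1.0, and `factor=0.5`, the blended target statistics are
+    mean 0.1 and std 1.25, so the channel is standardized with (0.2, 1.5) and re-scaled with (0.1, 1.25).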
+ """ + if ref_latents is None or factor is None or factor <= 0: + return curr_latents + + eps = torch.tensor(1e-6, device=curr_latents.device, dtype=curr_latents.dtype) + + # Compute per-channel means/stds for current and reference over (T, H, W) + mu_curr = curr_latents.mean(dim=(2, 3, 4), keepdim=True) + sigma_curr = curr_latents.std(dim=(2, 3, 4), keepdim=True) + + mu_ref = ref_latents.mean(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype) + sigma_ref = ref_latents.std(dim=(2, 3, 4), keepdim=True).to(device=curr_latents.device, dtype=curr_latents.dtype) + + # Blend target statistics + mu_blend = (1.0 - float(factor)) * mu_curr + float(factor) * mu_ref + sigma_blend = (1.0 - float(factor)) * sigma_curr + float(factor) * sigma_ref + sigma_blend = torch.clamp(sigma_blend, min=float(eps)) + + # Apply AdaIN + curr_norm = (curr_latents - mu_curr) / (sigma_curr + eps) + return curr_norm * sigma_blend + mu_blend + + +def split_into_temporal_windows( + latent_len: int, temporal_tile_size: int, temporal_overlap: int, compression: int +) -> List[Tuple[int, int]]: + """ + Split latent frames into sliding windows. + + Args: + latent_len: int. Number of latent frames (T_lat). + temporal_tile_size: int. Window size in latent frames (> 0). + temporal_overlap: int. Overlap between windows in latent frames (>= 0). + compression: int. VAE temporal compression ratio (unused here; kept for parity). + + Returns: + list[tuple[int, int]]: inclusive-exclusive (start, end) indices per window. + """ + if temporal_tile_size <= 0: + raise ValueError("temporal_tile_size must be > 0") + stride = max(temporal_tile_size - temporal_overlap, 1) + windows = [] + start = 0 + while start < latent_len: + end = min(start + temporal_tile_size, latent_len) + windows.append((start, end)) + if end == latent_len: + break + start = start + stride + return windows + + +def linear_overlap_fuse(prev: torch.Tensor, new: torch.Tensor, overlap: int) -> torch.Tensor: + """ + Temporal linear crossfade between two latent clips over the overlap region. + + Args: + prev: Tensor [B, C, F, H, W]. Previous output segment. + new: Tensor [B, C, F, H, W]. New segment to be appended. + overlap: int. Number of frames to crossfade (overlap <= 1 concatenates without blend). + + Returns: + Tensor [B, C, F_prev + F_new - overlap, H, W] after crossfade at the seam. + """ + if overlap <= 1: + return torch.cat([prev, new], dim=2) + alpha = torch.linspace(1, 0, overlap + 2, device=prev.device, dtype=prev.dtype)[1:-1] + shape = [1] * prev.ndim + shape[2] = alpha.size(0) + alpha = alpha.reshape(shape) + blended = alpha * prev[:, :, -overlap:] + (1 - alpha) * new[:, :, :overlap] + return torch.cat([prev[:, :, :-overlap], blended, new[:, :, overlap:]], dim=2) + + +def inject_prev_tail_latents( + window_latents: torch.Tensor, + prev_tail_latents: Optional[torch.Tensor], + window_cond_mask_5d: torch.Tensor, + overlap_lat: int, + strength: Optional[float], + prev_overlap_len: int, +) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Inject the tail latents from the previous window at the beginning of the current window (first k frames), where k = + min(overlap_lat, T_curr, T_prev_tail). + + Args: + window_latents: Tensor [B, C, T, H, W]. Current window latents. + prev_tail_latents: Optional[Tensor] [B, C, T_prev, H, W]. Tail segment from the previous window. + window_cond_mask_5d: Tensor [B, 1, T, H, W]. Per-token conditioning mask (1 = free, 0 = hard condition). + overlap_lat: int. 
Number of latent frames to inject from the previous tail. + strength: Optional[float] in [0, 1]. Blend strength; 1.0 replaces, 0.0 keeps original. + prev_overlap_len: int. Accumulated overlap length so far (used for trimming later). + + Returns: + Tuple[Tensor, Tensor, int]: (updated_window_latents, updated_cond_mask, updated_prev_overlap_len) + """ + if prev_tail_latents is None or overlap_lat <= 0 or strength is None or strength <= 0: + return window_latents, window_cond_mask_5d, prev_overlap_len + + # Expected shape: [B, C, T, H, W] + T = int(window_latents.shape[2]) + k = min(int(overlap_lat), T, int(prev_tail_latents.shape[2])) + if k <= 0: + return window_latents, window_cond_mask_5d, prev_overlap_len + + tail = prev_tail_latents[:, :, -k:] + mask = torch.full( + (window_cond_mask_5d.shape[0], 1, tail.shape[2], window_cond_mask_5d.shape[3], window_cond_mask_5d.shape[4]), + 1.0 - strength, + dtype=window_cond_mask_5d.dtype, + device=window_cond_mask_5d.device, + ) + + window_latents = torch.cat([window_latents, tail], dim=2) + window_cond_mask_5d = torch.cat([window_cond_mask_5d, mask], dim=2) + return window_latents, window_cond_mask_5d, prev_overlap_len + k + + +def build_video_coords_for_window( + latents: torch.Tensor, + overlap_len: int, + guiding_len: int, + negative_len: int, + rope_interpolation_scale: torch.Tensor, + frame_rate: int, +) -> torch.Tensor: + """ + Build video_coords: [B, 3, S] with order [t, y, x]. + + Args: + latents: Tensor [B, C, T, H, W]. Current window latents (before any trimming). + overlap_len: int. Number of frames from previous tail injected at the head. + guiding_len: int. Number of guidance frames appended at the head. + negative_len: int. Number of negative-index frames appended at the head (typically 1 or 0). + rope_interpolation_scale: tuple[int|float, int|float, int|float]. Scale for (t, y, x). + frame_rate: int. Used to convert time indices into seconds (t /= frame_rate). + + Returns: + Tensor [B, 3, T*H*W] of fractional pixel coordinates per latent patch. + """ + + b, c, f, h, w = latents.shape + pixel_coords = get_latent_coords(f, h, w, b, latents.device, rope_interpolation_scale, 0) + replace_corrds = [] + if overlap_len > 0: + replace_corrds.append(get_latent_coords(overlap_len, h, w, b, latents.device, rope_interpolation_scale, 0)) + if guiding_len > 0: + replace_corrds.append( + get_latent_coords(guiding_len, h, w, b, latents.device, rope_interpolation_scale, overlap_len) + ) + if negative_len > 0: + replace_corrds.append(get_latent_coords(negative_len, h, w, b, latents.device, rope_interpolation_scale, -1)) + if len(replace_corrds) > 0: + replace_corrds = torch.cat(replace_corrds, axis=2) + pixel_coords[:, :, -replace_corrds.shape[2] :] = replace_corrds + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + return fractional_coords + + +def parse_prompt_segments(prompt: Union[str, List[str]], prompt_segments: Optional[List[Dict[str, Any]]]) -> List[str]: + """ + Return a list of positive prompts per window index. + + Args: + prompt: str | list[str]. If str contains '|', parts are split by bars and trimmed. + prompt_segments: + list[dict], optional. Each dict with {"start_window", "end_window", "text"} overrides prompts per window. + + Returns: + list[str] containing the positive prompt for each window index. 
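+
+    Example (hypothetical inputs): in bar-split mode, "a chimpanzee walks | a chimpanzee eats" becomes
+    ["a chimpanzee walks", "a chimpanzee eats"]. In segment mode, [{"start_window": 0, "end_window": 1, "text":
+    "walking"}, {"start_window": 2, "end_window": 2, "text": "eating"}] becomes ["walking", "walking", "eating"];
+    any window left without text inherits the most recent non-empty entry. `prompt_segments` takes precedence over
+    bar-splitting when both are provided.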
+ """ + if prompt is None: + return [] + if prompt_segments: + max_w = 0 + for seg in prompt_segments: + max_w = max(max_w, int(seg.get("end_window", 0))) + texts = [""] * (max_w + 1) + for seg in prompt_segments: + s = int(seg.get("start_window", 0)) + e = int(seg.get("end_window", s)) + txt = seg.get("text", "") + for w in range(s, e + 1): + texts[w] = txt + # fill empty by last non-empty + last = "" + for i in range(len(texts)): + if texts[i] == "": + texts[i] = last + else: + last = texts[i] + return texts + + # bar-split mode + if isinstance(prompt, str): + parts = [p.strip() for p in prompt.split("|")] + else: + parts = prompt + parts = [p for p in parts if p is not None] + return parts + + +def batch_normalize(latents, reference, factor): + """ + Batch AdaIN-like normalization for latents in dict format (ComfyUI-compatible). + + Args: + latents: dict containing "samples" shaped [B, C, F, H, W] + reference: dict containing "samples" used to compute target stats + factor: float in [0, 1]; 0 = no change, 1 = full match to reference + Returns: + Tuple[dict]: a single-element tuple with the updated latents dict. + """ + latents_copy = copy.deepcopy(latents) + t = latents_copy["samples"] # B x C x F x H x W + + for i in range(t.size(0)): # batch + for c in range(t.size(1)): # channel + r_sd, r_mean = torch.std_mean(reference["samples"][i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(t[i, c], dim=None) + + t[i, c] = ((t[i, c] - i_mean) / i_sd) * r_sd + r_mean + + latents_copy["samples"] = torch.lerp(latents["samples"], t, factor) + return (latents_copy,) + + +class LTXI2VLongMultiPromptPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Long-duration I2V (image-to-video) multi-prompt pipeline with ComfyUI parity. + + Key features: + - Temporal sliding-window sampling only (no spatial H/W sharding); autoregressive fusion across windows. + - Multi-prompt segmentation per window with smooth transitions at window heads. + - First-frame hard conditioning via per-token mask for I2V. + - VRAM control via temporal windowing and VAE tiled decoding. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`] or [`LTXEulerAncestralRFScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + if not isinstance(scheduler, LTXEulerAncestralRFScheduler): + logger.warning( + "For ComfyUI parity, `LTXI2VLongMultiPromptPipeline` is typically run with " + "`LTXEulerAncestralRFScheduler`. Got %s.", + scheduler.__class__.__name__, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + self._current_tile_T = None + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.guidance_scale + def guidance_scale(self): + return self._guidance_scale + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.guidance_rescale + def guidance_rescale(self): + return self._guidance_rescale + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.do_classifier_free_guidance + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.num_timesteps + def num_timesteps(self): + return self._num_timesteps + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.current_timestep + def current_timestep(self): + return self._current_timestep + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.attention_kwargs + def attention_kwargs(self): + return self._attention_kwargs + + @property + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.interrupt + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + 
prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
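+        # For example (hypothetical sizes, patch_size = patch_size_t = 1): a packed tensor of shape [1, 2816, 128]
+        # coming from an 8 x 16 x 22 latent grid is reshaped back to [1, 128, 8, 16, 22].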
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + generator: Optional[torch.Generator], + dtype: torch.dtype = torch.float32, + latents: Optional[torch.Tensor] = None, + cond_latents: Optional[torch.Tensor] = None, + cond_strength: float = 0.0, + negative_index_latents: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int, int, int]: + """ + Prepare base latents and optionally inject first-frame conditioning latents. + + Returns: + latents, negative_index_latents, latent_num_frames, latent_height, latent_width + """ + if latents is None: + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = torch.zeros( + (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width), + device=device, + dtype=dtype, + ) + else: + latent_num_frames = latents.shape[2] + latent_height = latents.shape[3] + latent_width = latents.shape[4] + latents = latents.to(device=device, dtype=dtype) + + if cond_latents is not None and cond_strength > 0: + if negative_index_latents is None: + negative_index_latents = cond_latents + latents[:, :, :1, :, :] = cond_latents + + return latents, negative_index_latents, latent_num_frames, latent_height, latent_width + + # TODO: refactor this out + @torch.no_grad() + def vae_decode_tiled( + self, + latents: torch.Tensor, + decode_timestep: Optional[float] = None, + decode_noise_scale: Optional[float] = None, + horizontal_tiles: int = 4, + vertical_tiles: int = 4, + overlap: int = 3, + last_frame_fix: bool = True, + generator: Optional[torch.Generator] = None, + output_type: str = "pt", + auto_denormalize: bool = True, + compute_dtype: torch.dtype = torch.float32, + enable_vae_tiling: bool = False, + ) -> Union[torch.Tensor, np.ndarray, List[PIL.Image.Image]]: + """ + VAE-based spatial tiled decoding (ComfyUI parity) implemented in Diffusers style. 
+ - Linearly feather and blend overlapping tiles to avoid seams. + - Optional last_frame_fix: duplicate the last latent frame before decoding, then drop time_scale_factor frames + at the end. + - Supports timestep_conditioning and decode_noise_scale injection. + - By default, "normalized latents" (the denoising output) are de-normalized internally (auto_denormalize=True). + - Tile fusion is computed in compute_dtype (float32 by default) to reduce blur and color shifts. + + Args: + latents: [B, C_latent, F_latent, H_latent, W_latent] + decode_timestep: Optional decode timestep (effective only if VAE supports timestep_conditioning) + decode_noise_scale: + Optional decode noise interpolation (effective only if VAE supports timestep_conditioning) + horizontal_tiles, vertical_tiles: Number of tiles horizontally/vertically (>= 1) + overlap: Overlap in latent space (in latent pixels, >= 0) + last_frame_fix: Whether to enable the "repeat last frame" fix + generator: Random generator (used for decode_noise_scale noise) + output_type: "latent" | "pt" | "np" | "pil" + - "latent": return latents unchanged (useful for downstream processing) + - "pt": return tensor in VAE output space + - "np"/"pil": post-processed outputs via VideoProcessor.postprocess_video + auto_denormalize: If True, apply LTX de-normalization to `latents` internally (recommended) + compute_dtype: Precision used during tile fusion (float32 default; significantly reduces seam blur) + enable_vae_tiling: If True, delegate tiling to VAE's built-in `tiled_decode` (sets `vae.use_tiling`). + + Returns: + - If output_type="latent": returns input `latents` unchanged + - If output_type="pt": returns [B, C, F, H, W] (values roughly in [-1, 1]) + - If output_type="np"/"pil": returns post-processed outputs via postprocess_video + """ + if output_type == "latent": + return latents + if horizontal_tiles < 1 or vertical_tiles < 1: + raise ValueError("horizontal_tiles and vertical_tiles must be >= 1") + overlap = max(int(overlap), 0) + + # Device and precision + device = self._execution_device + latents = latents.to(device=device, dtype=compute_dtype) + + # De-normalize to VAE space (avoid color artifacts) + if auto_denormalize: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # dtype required for VAE forward pass + latents = latents.to(dtype=self.vae.dtype) + + # Temporal/spatial upscaling ratios (parity with ComfyUI's downscale_index_formula) + tsf = int(self.vae_temporal_compression_ratio) + sf = int(self.vae_spatial_compression_ratio) + + # Optional: last_frame_fix (repeat last latent frame) + if last_frame_fix: + latents = torch.cat([latents, latents[:, :, -1:].contiguous()], dim=2) + + b, c_lat, f_lat, h_lat, w_lat = latents.shape + f_out = 1 + (f_lat - 1) * tsf + h_out = h_lat * sf + w_out = w_lat * sf + + # timestep_conditioning + decode-time noise injection (aligned with pipeline) + if getattr(self.vae.config, "timestep_conditioning", False): + dt = float(decode_timestep) if decode_timestep is not None else 0.0 + vt = torch.tensor([dt], device=device, dtype=latents.dtype) + if decode_noise_scale is not None: + dns = torch.tensor([float(decode_noise_scale)], device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + latents = (1 - dns) * latents + dns * noise + else: + vt = None + + if enable_vae_tiling and hasattr(self.vae, "enable_tiling"): + 
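+            # Delegate tiling to the VAE's built-in tiled decode and return early,
+            # skipping the manual feathered tile fusion implemented below.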
self.vae.enable_tiling() + decoded = self.vae.decode(latents, vt, return_dict=False)[0] + if last_frame_fix: + decoded = decoded[:, :, :-tsf, :, :] + if output_type in ("np", "pil"): + return self.video_processor.postprocess_video(decoded, output_type=output_type) + return decoded + + # Compute base tile sizes (in latent space) + base_tile_h = (h_lat + (vertical_tiles - 1) * overlap) // vertical_tiles + base_tile_w = (w_lat + (horizontal_tiles - 1) * overlap) // horizontal_tiles + + output: Optional[torch.Tensor] = None # [B, C_img, F, H, W], fused using compute_dtype + weights: Optional[torch.Tensor] = None # [B, 1, F, H, W], fused using compute_dtype + + # Iterate tiles in latent space (no temporal tiling) + for v in range(vertical_tiles): + for h in range(horizontal_tiles): + h_start = h * (base_tile_w - overlap) + v_start = v * (base_tile_h - overlap) + + h_end = min(h_start + base_tile_w, w_lat) if h < horizontal_tiles - 1 else w_lat + v_end = min(v_start + base_tile_h, h_lat) if v < vertical_tiles - 1 else h_lat + + # Slice latent tile and decode + tile_latents = latents[:, :, :, v_start:v_end, h_start:h_end] + decoded_tile = self.vae.decode(tile_latents, vt, return_dict=False)[0] # [B, C, F, Ht, Wt] + # Cast to high precision to reduce blending blur + decoded_tile = decoded_tile.to(dtype=compute_dtype) + + # Initialize output buffers (compute_dtype) + if output is None: + output = torch.zeros( + (b, decoded_tile.shape[1], f_out, h_out, w_out), + device=decoded_tile.device, + dtype=compute_dtype, + ) + weights = torch.zeros( + (b, 1, f_out, h_out, w_out), + device=decoded_tile.device, + dtype=compute_dtype, + ) + + # Tile placement in output pixel space + out_h_start = v_start * sf + out_h_end = v_end * sf + out_w_start = h_start * sf + out_w_end = h_end * sf + + tile_out_h = out_h_end - out_h_start + tile_out_w = out_w_end - out_w_start + + # Linear feathering weights [B, 1, F, Ht, Wt] (compute_dtype) + tile_weights = torch.ones( + (b, 1, decoded_tile.shape[2], tile_out_h, tile_out_w), + device=decoded_tile.device, + dtype=compute_dtype, + ) + + overlap_out_h = overlap * sf + overlap_out_w = overlap * sf + + # Horizontal feathering: left/right overlaps + if overlap_out_w > 0: + if h > 0: + h_blend = torch.linspace( + 0, 1, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, :, :overlap_out_w] *= h_blend.view(1, 1, 1, 1, -1) + if h < horizontal_tiles - 1: + h_blend = torch.linspace( + 1, 0, steps=overlap_out_w, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, :, -overlap_out_w:] *= h_blend.view(1, 1, 1, 1, -1) + + # Vertical feathering: top/bottom overlaps + if overlap_out_h > 0: + if v > 0: + v_blend = torch.linspace( + 0, 1, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, :overlap_out_h, :] *= v_blend.view(1, 1, 1, -1, 1) + if v < vertical_tiles - 1: + v_blend = torch.linspace( + 1, 0, steps=overlap_out_h, device=decoded_tile.device, dtype=compute_dtype + ) + tile_weights[:, :, :, -overlap_out_h:, :] *= v_blend.view(1, 1, 1, -1, 1) + + # Accumulate blended tile + output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += decoded_tile * tile_weights + weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += tile_weights + + # Normalize, then clamp to [-1, 1] in compute_dtype to avoid color artifacts + output = output / (weights + 1e-8) + output = output.clamp(-1.0, 1.0) + output = output.to(dtype=self.vae.dtype) + + # Optional: drop the last tsf 
frames after last_frame_fix + if last_frame_fix: + output = output[:, :, :-tsf, :, :] + + if output_type in ("np", "pil"): + return self.video_processor.postprocess_video(output, output_type=output_type) + return output + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_segments: Optional[List[Dict[str, Any]]] = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: float = 25, + guidance_scale: float = 1.0, + guidance_rescale: float = 0.0, + num_inference_steps: Optional[int] = 8, + sigmas: Optional[Union[List[float], torch.Tensor]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + seed: Optional[int] = 0, + cond_image: Optional[Union["PIL.Image.Image", torch.Tensor]] = None, + cond_strength: float = 0.5, + latents: Optional[torch.Tensor] = None, + temporal_tile_size: int = 80, + temporal_overlap: int = 24, + temporal_overlap_cond_strength: float = 0.5, + adain_factor: float = 0.25, + guidance_latents: Optional[torch.Tensor] = None, + guiding_strength: float = 1.0, + negative_index_latents: Optional[torch.Tensor] = None, + negative_index_strength: float = 1.0, + skip_steps_sigma_threshold: Optional[float] = 1, + decode_timestep: Optional[float] = 0.05, + decode_noise_scale: Optional[float] = 0.025, + decode_horizontal_tiles: int = 4, + decode_vertical_tiles: int = 4, + decode_overlap: int = 3, + output_type: Optional[str] = "latent", # "latent" | "pt" | "np" | "pil" + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Generate an image-to-video sequence via temporal sliding windows and multi-prompt scheduling. + + Args: + prompt (`str` or `List[str]`, *optional*): + Positive text prompt(s) per window. If a single string contains '|', parts are split by bars. + negative_prompt (`str` or `List[str]`, *optional*): + Negative prompt(s) to suppress undesired content. + prompt_segments (`List[dict]`, *optional*): + Segment mapping with {"start_window", "end_window", "text"} to override prompts per window. + height (`int`, defaults to `512`): + Output image height in pixels; must be divisible by 32. + width (`int`, defaults to `704`): + Output image width in pixels; must be divisible by 32. + num_frames (`int`, defaults to `161`): + Number of output frames (in decoded pixel space). + frame_rate (`float`, defaults to `25`): + Frames-per-second; used to normalize temporal coordinates in `video_coords`. + guidance_scale (`float`, defaults to `1.0`): + CFG scale; values > 1 enable classifier-free guidance. + guidance_rescale (`float`, defaults to `0.0`): + Optional rescale to mitigate overexposure under CFG (see `rescale_noise_cfg`). + num_inference_steps (`int`, *optional*, defaults to `8`): + Denoising steps per window. Ignored if `sigmas` is provided. + sigmas (`List[float]` or `torch.Tensor`, *optional*): + Explicit sigma schedule per window; if set, overrides `num_inference_steps`. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + Controls stochasticity; list accepted but first element is used (batch=1). 
+ seed (`int`, *optional*, defaults to `0`): + If provided, seeds the shared generator for global latents and derives a window-local generator with + `seed + w_start` per temporal window. + cond_image (`PIL.Image.Image` or `torch.Tensor`, *optional*): + Conditioning image; fixes frame 0 via per-token mask when `cond_strength > 0`. + cond_strength (`float`, defaults to `0.5`): + Strength of first-frame hard conditioning (smaller cond_mask ⇒ stronger preservation). + latents (`torch.Tensor`, *optional*): + Initial latents [B, C_lat, F_lat, H_lat, W_lat]; if None, sampled with `randn_tensor`. + temporal_tile_size (`int`, defaults to `80`): + Temporal window size (in decoded frames); internally scaled by VAE temporal compression. + temporal_overlap (`int`, defaults to `24`): + Overlap between consecutive windows (in decoded frames); internally scaled by compression. + temporal_overlap_cond_strength (`float`, defaults to `0.5`): + Strength for injecting previous window tail latents at new window head. + adain_factor (`float`, defaults to `0.25`): + AdaIN normalization strength for cross-window consistency (0 disables). + guidance_latents (`torch.Tensor`, *optional*): + Reference latents injected at window head; length trimmed by overlap for subsequent windows. + guiding_strength (`float`, defaults to `1.0`): + Injection strength for `guidance_latents`. + negative_index_latents (`torch.Tensor`, *optional*): + A single-frame latent appended at window head for "negative index" semantics. + negative_index_strength (`float`, defaults to `1.0`): + Injection strength for `negative_index_latents`. + skip_steps_sigma_threshold (`float`, *optional*, defaults to `1`): + Skip steps whose sigma exceeds this threshold. + decode_timestep (`float`, *optional*, defaults to `0.05`): + Decode-time timestep (if VAE supports timestep_conditioning). + decode_noise_scale (`float`, *optional*, defaults to `0.025`): + Decode-time noise mix scale (if VAE supports timestep_conditioning). + decode_horizontal_tiles (`int`, defaults to `4`): + Number of horizontal tiles during VAE decoding. + decode_vertical_tiles (`int`, defaults to `4`): + Number of vertical tiles during VAE decoding. + decode_overlap (`int`, defaults to `3`): + Overlap (in latent pixels) between tiles during VAE decoding. + output_type (`str`, *optional*, defaults to `"latent"`): + The output format of the generated video. Choose between "latent", "pt", "np", or "pil". If "latent", + returns latents without decoding. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Extra attention parameters forwarded to the transformer. + callback_on_step_end (`PipelineCallback` or `MultiPipelineCallbacks`, *optional*): + Per-step callback hook. + callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`): + Keys from locals() to pass into the callback. + max_sequence_length (`int`, defaults to `128`): + Tokenizer max length for prompt encoding. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated frames. The output format depends on + `output_type`: + - "latent"/"pt": `torch.Tensor` [B, C, F, H, W]; "latent" is in normalized latent space, "pt" is VAE + output space. + - "np": `np.ndarray` post-processed. 
+ - "pil": `List[PIL.Image.Image]` list of PIL images. + + Shapes: + Latent sizes (when auto-generated): + - F_lat = (num_frames - 1) // vae_temporal_compression_ratio + 1 + - H_lat = height // vae_spatial_compression_ratio + - W_lat = width // vae_spatial_compression_ratio + + Notes: + - Seeding: when `seed` is provided, each temporal window uses a local generator seeded with `seed + + w_start`, while the shared generator is seeded once for global latents if no generator is passed; + otherwise the passed-in generator is reused. + - CFG: unified `noise_pred = uncond + w * (text - uncond)` with optional `guidance_rescale`. + - Memory: denoising performs full-frame predictions (no spatial tiling); decoding can be tiled to avoid + OOM. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Input validation: height/width must be divisible by 32 + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 1. Device & generator + device = self._execution_device + # Normalize generator input: accept list but use the first (batch_size=1) + if isinstance(generator, list): + generator = generator[0] + if seed is not None and generator is None: + generator = torch.Generator(device=device).manual_seed(seed) + + # 2. Optional i2v first frame conditioning: encode cond_image and inject at frame 0 via prepare_latents + cond_latents = None + if cond_image is not None and cond_strength > 0: + img = self.video_processor.preprocess(cond_image, height=height, width=width) + img = img.to(device=device, dtype=self.vae.dtype) + enc = self.vae.encode(img.unsqueeze(2)) # [B, C, 1, h, w] + cond_latents = enc.latent_dist.mode() if hasattr(enc, "latent_dist") else enc.latents + cond_latents = cond_latents.to(torch.float32) + cond_latents = self._normalize_latents( + cond_latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + # 3. Global initial latents [B,C,F,H,W], optionally seeded/conditioned + latents, negative_index_latents, latent_num_frames, latent_height, latent_width = self.prepare_latents( + batch_size=1, + num_channels_latents=self.transformer.config.in_channels, + height=height, + width=width, + num_frames=num_frames, + device=device, + generator=generator, + dtype=torch.float32, + latents=latents, + cond_latents=cond_latents, + cond_strength=cond_strength, + negative_index_latents=negative_index_latents, + ) + if guidance_latents is not None: + guidance_latents = guidance_latents.to(device=device, dtype=torch.float32) + if latents.shape[2] != guidance_latents.shape[2]: + raise ValueError("The number of frames in `latents` and `guidance_latents` must be the same") + + # 4. Sliding windows in latent frames + tile_size_lat = max(1, temporal_tile_size // self.vae_temporal_compression_ratio) + overlap_lat = max(0, temporal_overlap // self.vae_temporal_compression_ratio) + windows = split_into_temporal_windows( + latent_num_frames, tile_size_lat, overlap_lat, self.vae_temporal_compression_ratio + ) + + # 5. Multi-prompt segments parsing + segment_texts = parse_prompt_segments(prompt, prompt_segments) + + out_latents = None + first_window_latents = None + + # 6. 
Process each temporal window + for w_idx, (w_start, w_end) in enumerate(windows): + if self.interrupt: + break + + # 6.1 Encode prompt embeddings per window segment + seg_index = min(w_idx, len(segment_texts) - 1) if segment_texts else 0 + pos_text = segment_texts[seg_index] if segment_texts else (prompt if isinstance(prompt, str) else "") + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=[pos_text], + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=1, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + max_sequence_length=max_sequence_length, + device=device, + dtype=None, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 6.2 Window-level timesteps reset: fresh sampling for each temporal window + if sigmas is not None: + s = torch.tensor(sigmas, dtype=torch.float32) if not isinstance(sigmas, torch.Tensor) else sigmas + self.scheduler.set_timesteps(sigmas=s, device=device) + self._num_timesteps = len(sigmas) + else: + self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device) + self._num_timesteps = num_inference_steps + + # 6.3 Extract window latents [B,C,T,H,W] + window_latents = latents[:, :, w_start:w_end] + window_guidance_latents = guidance_latents[:, :, w_start:w_end] if guidance_latents is not None else None + window_T = window_latents.shape[2] + + # 6.4 Build per-window cond mask and inject previous tails / reference + window_cond_mask_5d = torch.ones( + (1, 1, window_T, latent_height, latent_width), device=device, dtype=torch.float32 + ) + self._current_tile_T = window_T + prev_overlap_len = 0 + # Inter-window tail latent injection (Extend) + if w_idx > 0 and overlap_lat > 0 and out_latents is not None: + k = min(overlap_lat, out_latents.shape[2]) + prev_tail = out_latents[:, :, -k:] + window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents( + window_latents, + prev_tail, + window_cond_mask_5d, + overlap_lat, + temporal_overlap_cond_strength, + prev_overlap_len, + ) + # Reference/negative-index latent injection (append 1 frame at window head; controlled by negative_index_strength) + if window_guidance_latents is not None: + guiding_len = ( + window_guidance_latents.shape[2] if w_idx == 0 else window_guidance_latents.shape[2] - overlap_lat + ) + window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents( + window_latents, + window_guidance_latents[:, :, -guiding_len:], + window_cond_mask_5d, + guiding_len, + guiding_strength, + prev_overlap_len, + ) + else: + guiding_len = 0 + window_latents, window_cond_mask_5d, prev_overlap_len = inject_prev_tail_latents( + window_latents, + negative_index_latents, + window_cond_mask_5d, + 1, + negative_index_strength, + prev_overlap_len, + ) + if w_idx == 0 and cond_image is not None and cond_strength > 0: + # First-frame I2V: smaller mask means stronger preservation of the original latent + window_cond_mask_5d[:, :, 0] = 1.0 - cond_strength + + # Update effective window latent sizes (consider injections on T/H/W) + w_B, w_C, w_T_eff, w_H_eff, w_W_eff = window_latents.shape + p = self.transformer_spatial_patch_size + pt = self.transformer_temporal_patch_size + + # 6.5 
Pack full-window latents/masks once + # Seeding policy: derive a window-local generator to decouple RNG across windows + if seed is not None: + tile_seed = int(seed) + int(w_start) + local_gen = torch.Generator(device=device).manual_seed(tile_seed) + else: + local_gen = generator + # randn*mask + (1-mask)*latents implements hard-condition initialization + init_rand = randn_tensor(window_latents.shape, generator=local_gen, device=device, dtype=torch.float32) + mixed_latents = init_rand * window_cond_mask_5d + (1 - window_cond_mask_5d) * window_latents + window_latents_packed = self._pack_latents( + window_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + latents_packed = self._pack_latents( + mixed_latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + cond_mask_tokens = self._pack_latents( + window_cond_mask_5d, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if self.do_classifier_free_guidance: + cond_mask = torch.cat([cond_mask_tokens, cond_mask_tokens], dim=0) + else: + cond_mask = cond_mask_tokens + + # 6.6 Denoising loop per full window (no spatial tiling) + sigmas_current = self.scheduler.sigmas.to(device=latents_packed.device) + if sigmas_current.shape[0] >= 2: + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[:-1])): + if self.interrupt: + break + # Skip semantics: if sigma exceeds threshold, skip this step (do not call scheduler.step) + sigma_val = float(sigmas_current[i].item()) + if skip_steps_sigma_threshold is not None and float(skip_steps_sigma_threshold) > 0.0: + if sigma_val > float(skip_steps_sigma_threshold): + continue + + self._current_timestep = t + + # Model input (stack 2 copies under CFG) + latent_model_input = ( + torch.cat([latents_packed] * 2) if self.do_classifier_free_guidance else latents_packed + ) + # Broadcast timesteps, combine with per-token cond mask (I2V at window head) + timestep = t.expand(latent_model_input.shape[0]) + if cond_mask is not None: + # Broadcast timestep to per-token mask under CFG: [B] -> [B, S, 1] + timestep = timestep[:, None, None] * cond_mask + + # Micro-conditions: only provide video_coords (num_frames/height/width set to 1) + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Inpainting pre-blend (ComfyUI parity: KSamplerX0Inpaint:400) + if cond_mask_tokens is not None: + latents_packed = latents_packed * cond_mask_tokens + window_latents_packed * ( + 1.0 - cond_mask_tokens + ) + + # Negative-index/overlap lengths (for segmenting time coordinates; RoPE-compatible) + k_negative_count = ( + 1 if (negative_index_latents is not None and float(negative_index_strength) > 0.0) else 0 + ) + k_overlap_count = overlap_lat if (w_idx > 0 and overlap_lat > 0) else 0 + video_coords = build_video_coords_for_window( + latents=window_latents, + overlap_len=int(k_overlap_count), + guiding_len=int(guiding_len), + negative_len=int(k_negative_count), + rope_interpolation_scale=rope_interpolation_scale, + frame_rate=frame_rate, + ) + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input.to(dtype=self.transformer.dtype), + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=1, + height=1, + width=1, + rope_interpolation_scale=rope_interpolation_scale, + video_coords=video_coords, + 
attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + # Unified CFG + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + if self.guidance_rescale > 0: + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale + ) + + # Use global timestep for scheduling, but apply suppressive blending with hard-condition tokens (e.g., first frame) after step to avoid brightness/flicker due to time misalignment + latents_packed = self.scheduler.step( + noise_pred, t, latents_packed, generator=local_gen, return_dict=False + )[0] + # Inpainting post-blend (ComfyUI parity: restore hard-conditioned regions after update) + if cond_mask_tokens is not None: + latents_packed = latents_packed * cond_mask_tokens + window_latents_packed * ( + 1.0 - cond_mask_tokens + ) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents_packed = callback_outputs.pop("latents", latents_packed) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if XLA_AVAILABLE: + xm.mark_step() + else: + # Not enough sigmas to perform a valid step; skip this window safely. + pass + + # 6.7 Unpack back to [B,C,T,H,W] once + window_out = self._unpack_latents( + latents_packed, + w_T_eff, + w_H_eff, + w_W_eff, + p, + pt, + ) + if prev_overlap_len > 0: + window_out = window_out[:, :, :-prev_overlap_len] + + # 6.8 Overlap handling and fusion + if out_latents is None: + # First window: keep all latent frames and cache as AdaIN reference + out_latents = window_out + first_window_latents = out_latents + else: + window_out = window_out[:, :, 1:] # Drop the first frame of the new window + if adain_factor > 0 and first_window_latents is not None: + window_out = adain_normalize_latents(window_out, first_window_latents, adain_factor) + overlap_len = max(overlap_lat - 1, 1) + prev_tail_chunk = out_latents[:, :, -window_out.shape[2] :] + fused = linear_overlap_fuse(prev_tail_chunk, window_out, overlap_len) + out_latents = torch.cat([out_latents[:, :, : -window_out.shape[2]], fused], dim=2) + + # 7. 
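# Illustrative sketch, not from this diff: `linear_overlap_fuse` above blends the tail of the
# previous window with the head of the new one over `overlap_len` latent frames. A simple
# crossfade of that kind could look like the following (a sketch of the idea, not the exact helper).
import torch

def crossfade_overlap(prev_chunk: torch.Tensor, new_chunk: torch.Tensor, overlap_len: int) -> torch.Tensor:
    # prev_chunk / new_chunk: [B, C, T, H, W] with identical T; blend the first `overlap_len` frames.
    fused = new_chunk.clone()
    overlap_len = min(overlap_len, new_chunk.shape[2])
    # Weights ramp from "mostly previous" to "mostly new" across the overlap region.
    weights = torch.linspace(0.0, 1.0, overlap_len + 2, device=new_chunk.device)[1:-1].view(1, 1, -1, 1, 1)
    fused[:, :, :overlap_len] = (1.0 - weights) * prev_chunk[:, :, :overlap_len] + weights * new_chunk[:, :, :overlap_len]
    return fused

# With prev_chunk = zeros, new_chunk = ones and overlap_len = 3, frames 0..2 ramp from 0.25 to 0.75
# and the remaining frames stay at 1.0.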
Decode or return latent + if output_type == "latent": + video = out_latents + else: + # Decode via tiling to avoid OOM from full-frame decoding; latents are already de-normalized, so keep auto_denormalize disabled + video = self.vae_decode_tiled( + out_latents, + decode_timestep=decode_timestep, + decode_noise_scale=decode_noise_scale, + horizontal_tiles=int(decode_horizontal_tiles), + vertical_tiles=int(decode_vertical_tiles), + overlap=int(decode_overlap), + generator=generator, + output_type=output_type, # Keep type consistent; postprocess is applied afterwards + ) + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index f30f8a3dc8f6..3226b045cccb 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -798,10 +798,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, timesteps, sigmas=sigmas, mu=mu, diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py new file mode 100644 index 000000000000..115e83e827a4 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["connectors"] = ["LTX2TextConnectors"] + _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"] + _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] + _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] + _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"] + _import_structure["vocoder"] = ["LTX2Vocoder"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .connectors import LTX2TextConnectors + from .latent_upsampler import LTX2LatentUpsamplerModel + from .pipeline_ltx2 import LTX2Pipeline + from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline + from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline + from .vocoder import LTX2Vocoder + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py new file mode 100644 index 
000000000000..22ca42d37902 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -0,0 +1,326 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.attention import FeedForward +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ): + super().__init__() + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads + + def forward( + self, + batch_size: int, + pos: int, + device: Union[str, torch.device], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. 
Get real, interleaved (cos, sin) frequencies, padded to self.dim + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + rope_type: str = "interleaved", + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + rope_type=rope_type, + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
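# Illustrative sketch, not from this diff: the RoPE module above only builds the (cos, sin) tables;
# the rotation itself is applied inside LTX2Attention / LTX2AudioVideoAttnProcessor (defined in
# transformer_ltx2.py, not shown in this section). The conventional way interleaved tables of
# shape [B, L, D] are applied to a query/key tensor looks like this:
import torch

def apply_interleaved_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [B, L, D] with D even; cos/sin built with repeat_interleave(2, dim=-1) as above.
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    # Rotate each (even, odd) pair: (x_even, x_odd) -> (-x_odd, x_even), re-interleaved back to D.
    rotated = torch.stack((-x_odd, x_even), dim=-1).flatten(-2)
    return x * cos + rotated * sin

x = torch.randn(2, 8, 16)
identity = apply_interleaved_rope(x, torch.ones_like(x), torch.zeros_like(x))
assert torch.allclose(identity, x)  # zero-angle tables leave the tensor unchanged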
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. 
Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. + """ + + @register_to_config + def __init__( + self, + caption_channels: int, + text_proj_in_factor: int, + video_connector_num_attention_heads: int, + video_connector_attention_head_dim: int, + video_connector_num_layers: int, + video_connector_num_learnable_registers: int | None, + audio_connector_num_attention_heads: int, + audio_connector_attention_head_dim: int, + audio_connector_num_layers: int, + audio_connector_num_learnable_registers: int | None, + connector_rope_base_seq_len: int, + rope_theta: float, + rope_double_precision: bool, + causal_temporal_positioning: bool, + rope_type: str = "interleaved", + ): + super().__init__() + self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + + def forward( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/src/diffusers/pipelines/ltx2/export_utils.py b/src/diffusers/pipelines/ltx2/export_utils.py new file mode 100644 index 000000000000..0bc7a59db228 --- 
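# Illustrative sketch, not from this diff: the connector `forward` above turns a binary attention
# mask (1 = real token, 0 = padding) into an additive bias, with 0 on real tokens and a very large
# negative value on padding, which is the form softmax attention consumes.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
dtype = torch.bfloat16
additive = (attention_mask - 1).to(dtype) * torch.finfo(dtype).max
additive = additive.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])  # [B, 1, 1, L]
# real tokens -> 0.0, padded tokens -> -finfo(bfloat16).max, so they vanish after softmax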
/dev/null +++ b/src/diffusers/pipelines/ltx2/export_utils.py @@ -0,0 +1,134 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fractions import Fraction +from typing import Optional + +import torch + +from ...utils import is_av_available + + +_CAN_USE_AV = is_av_available() +if _CAN_USE_AV: + import av +else: + raise ImportError( + "PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`" + ) + + +def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream: + """ + Prepare the audio stream for writing. + """ + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _resample_audio( + container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame +) -> None: + cc = audio_stream.codec_context + + # Use the encoder's format/layout/rate as the *target* + target_format = cc.format or "fltp" # AAC → usually fltp + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + audio_resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in audio_resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + # flush audio encoder + for packet in audio_stream.encode(): + container.mux(packet) + + +def _write_audio( + container: av.container.Container, + audio_stream: av.audio.AudioStream, + samples: torch.Tensor, + audio_sample_rate: int, +) -> None: + if samples.ndim == 1: + samples = samples[:, None] + + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + + if samples.shape[1] != 2: + raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") + + # Convert to int16 packed for ingestion; resampler converts to encoder fmt. 
+ if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + _resample_audio(container, audio_stream, frame_in) + + +def encode_video( + video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str +) -> None: + video_np = video.cpu().numpy() + + _, height, width, _ = video_np.shape + + container = av.open(output_path, mode="w") + stream = container.add_stream("libx264", rate=int(fps)) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + + if audio is not None: + if audio_sample_rate is None: + raise ValueError("audio_sample_rate is required when audio is provided") + + audio_stream = _prepare_audio_stream(container, audio_sample_rate) + + for frame_array in video_np: + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + + # Flush encoder + for packet in stream.encode(): + container.mux(packet) + + if audio is not None: + _write_audio(container, audio_stream, audio, audio_sample_rate) + + container.close() diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py new file mode 100644 index 000000000000..69a9b1d9193f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -0,0 +1,285 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +RATIONAL_RESAMPLER_SCALE_MAPPING = { + 0.75: (3, 4), + 1.5: (3, 2), + 2.0: (2, 1), + 4.0: (4, 1), +} + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock +class ResBlock(torch.nn.Module): + def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = torch.nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = torch.nn.GroupNorm(32, channels) + self.activation = torch.nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.norm2(hidden_states) + hidden_states = self.activation(hidden_states + residual) + return hidden_states + + +# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND +class PixelShuffleND(torch.nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + + self.dims = dims + self.upscale_factors = upscale_factors + + if dims not in [1, 2, 3]: + raise ValueError("dims must be 1, 2, or 3") + + def forward(self, x): + if self.dims == 3: + # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:3])) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(6, 7) + .flatten(4, 5) + .flatten(2, 3) + ) + elif self.dims == 2: + # spatial: b (c p1 p2) h w -> b c (h p1) (w p2) + return ( + x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3) + ) + elif self.dims == 1: + # temporal: b (c p1) f h w -> b c (f p1) h w + return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3) + + +class BlurDownsample(torch.nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W. + Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None: + super().__init__() + + if dims not in (2, 3): + raise ValueError(f"`dims` must be either 2 or 3 but is {dims}") + if kernel_size < 3 or kernel_size % 2 != 1: + raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}") + + self.dims = dims + self.stride = stride + self.kernel_size = kernel_size + + # 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from + # the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and + # provides a smooth approximation of a Gaussian filter (often called a "binomial filter"). + # The 2D kernel is constructed as the outer product and normalized. 
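# Illustrative sanity check, not from this diff: for dims=2 the reshuffle in PixelShuffleND above
# reproduces torch.nn.PixelShuffle exactly, which is an easy way to convince yourself that the
# permutation is the standard depth-to-space operation.
import torch

r = 2
x = torch.randn(1, 8 * r * r, 4, 4)
manual = (
    x.unflatten(1, (-1, r, r))      # [B, C, r, r, H, W]
    .permute(0, 1, 4, 2, 5, 3)      # [B, C, H, r, W, r]
    .flatten(4, 5)                  # [B, C, H, r, W*r]
    .flatten(2, 3)                  # [B, C, H*r, W*r]
)
assert torch.equal(manual, torch.nn.PixelShuffle(r)(x))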
+ k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + if self.dims == 2: + c = x.shape[1] + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + else: + # dims == 3: apply per-frame on H,W + b, c, f, _, _ = x.shape + x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + + h2, w2 = x.shape[-2:] + x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class SpatialRationalResampler(torch.nn.Module): + """ + Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample + by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the + input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the + (integer) denominator. + """ + + def __init__(self, mid_channels: int = 1024, scale: float = 2.0): + super().__init__() + self.scale = float(scale) + num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None) + if num_denom is None: + raise ValueError( + f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}" + ) + self.num, self.den = num_denom + + self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Expected x shape: [B * F, C, H, W] + # b, _, f, h, w = x.shape + # x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + # x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W] + return x + + +class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin): + """ + Model to spatially upsample VAE latents. 
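# Illustrative sketch, not from this diff: SpatialRationalResampler reaches a fractional scale by
# chaining an integer upsample (pixel shuffle by the numerator) with an anti-aliased integer
# downsample (blur + strided conv by the denominator). For scale=1.5 (num=3, den=2) and the
# kernel_size=5 / padding=2 blur above, the spatial size goes H -> 3H -> 3H/2:
def blur_down_out(size: int, stride: int, kernel_size: int = 5) -> int:
    # Standard Conv2d output-size formula with padding = kernel_size // 2
    return (size + 2 * (kernel_size // 2) - kernel_size) // stride + 1

h = 32
after_shuffle = h * 3                                  # 96
after_blur = blur_down_out(after_shuffle, stride=2)    # 48 -> net scale factor of 1.5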
+ + Args: + in_channels (`int`, defaults to `128`): + Number of channels in the input latent + mid_channels (`int`, defaults to `512`): + Number of channels in the middle layers + num_blocks_per_stage (`int`, defaults to `4`): + Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`, defaults to `3`): + Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`, defaults to `True`): + Whether to spatially upsample the latent + temporal_upsample (`bool`, defaults to `False`): + Whether to temporally upsample the latent + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 1024, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + rational_spatial_scale: Optional[float] = 2.0, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + + ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d + + self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = torch.nn.GroupNorm(32, mid_channels) + self.initial_activation = torch.nn.SiLU() + + self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]) + + if spatial_upsample and temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_spatial_scale is not None: + self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale) + else: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = torch.nn.Sequential( + torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError("Either spatial_upsample or temporal_upsample must be True") + + self.post_upsample_res_blocks = torch.nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.dims == 2: + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.upsampler(hidden_states) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + else: + hidden_states = self.initial_conv(hidden_states) + hidden_states = self.initial_norm(hidden_states) + hidden_states = self.initial_activation(hidden_states) + + for block in self.res_blocks: + hidden_states = block(hidden_states) + + if self.temporal_upsample: + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states[:, :, 1:, :, :] + else: + 
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + for block in self.post_upsample_res_blocks: + hidden_states = block(hidden_states) + + hidden_states = self.final_conv(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py new file mode 100644 index 000000000000..a92a7a2c8869 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -0,0 +1,1224 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from ...models.transformers import LTX2VideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors +from .pipeline_output import LTX2PipelineOutput +from .vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... 
) + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
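# Worked example, not from this diff: calculate_shift is simply the line through the two anchor
# points (base_seq_len, base_shift) and (max_seq_len, max_shift). With the defaults above and a
# hypothetical 6144-token latent sequence (e.g. a 16 x 16 x 24 latent grid):
m = (1.15 - 0.5) / (4096 - 256)    # ~1.693e-4
b = 0.5 - m * 256                  # ~0.457
mu = 6144 * m + b                  # ~1.497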
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. 
+ """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
+ """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. 
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
+    @staticmethod
+    def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+        latents_mean = latents_mean.to(latents.device, latents.dtype)
+        latents_std = latents_std.to(latents.device, latents.dtype)
+        return (latents - latents_mean) / latents_std
+
+    @staticmethod
+    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+        latents_mean = latents_mean.to(latents.device, latents.dtype)
+        latents_std = latents_std.to(latents.device, latents.dtype)
+        return (latents * latents_std) + latents_mean
+
+    @staticmethod
+    def _create_noised_state(
+        latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
+    ):
+        noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
+        noised_latents = noise_scale * noise + (1 - noise_scale) * latents
+        return noised_latents
+
+    @staticmethod
+    def _pack_audio_latents(
+        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+    ) -> torch.Tensor:
+        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+            latents = latents.reshape(
+                batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+            )
+            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
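+            # Illustrative example (values are hypothetical): audio latents of shape [B, C, L, M] = [2, 8, 50, 16]
+            # become a sequence of shape [2, 50, 8 * 16] = [2, 50, 128] after the transpose/flatten below.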
+ latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 0.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." 
+ ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: Optional[List[float]] = None, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + audio_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. 
+            num_inference_steps (`int`, *optional*, defaults to 40):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to `4.0`):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+                [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+                using zero terminal SNR.
+            noise_scale (`float`, *optional*, defaults to `0.0`):
+                The interpolation factor between random noise and the supplied latents. Noise is applied to the
+                provided `latents` and `audio_latents` before denoising continues.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            audio_latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            decode_timestep (`float`, defaults to `0.0`):
+                The timestep at which generated video is decoded.
+            decode_noise_scale (`float`, defaults to `None`):
+                The interpolation factor between random noise and denoised latents at the decode timestep.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx2.LTX2PipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to `1024`):
+                Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ltx2.LTX2PipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx2.LTX2PipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt=prompt,
+            height=height,
+            width=width,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._attention_kwargs = attention_kwargs
+        self._interrupt = False
+        self._current_timestep = None
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3.
Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." 
+ ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py new file mode 100644 index 000000000000..04d7ee89c52a --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -0,0 +1,1308 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
+from ...models.transformers import LTX2VideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .connectors import LTX2TextConnectors
+from .pipeline_output import LTX2PipelineOutput
+from .vocoder import LTX2Vocoder
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import LTX2ImageToVideoPipeline
+        >>> from diffusers.pipelines.ltx2.export_utils import encode_video
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+        ... )
+        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
+        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+        >>> frame_rate = 24.0
+        >>> video, audio = pipe(
+        ...     image=image,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     width=768,
+        ...     height=512,
+        ...     num_frames=121,
+        ...     frame_rate=frame_rate,
+        ...     num_inference_steps=40,
+        ...     guidance_scale=4.0,
+        ...     output_type="np",
+        ...     return_dict=False,
+        ... )
+        >>> video = (video * 255).round().astype("uint8")
+        >>> video = torch.from_numpy(video)
+
+        >>> encode_video(
+        ...     video[0],
+        ...     fps=frame_rate,
+        ...     audio=audio[0].float().cpu(),
+        ...     audio_sample_rate=pipe.vocoder.config.output_sampling_rate,  # should be 24000
+        ...     output_path="video.mp4",
+        ...
) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is corrct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + 
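+        # As an illustrative sanity check (using the fallback defaults below: sample_rate=16000,
+        # mel_hop_length=160, and an audio temporal compression ratio of 4), one audio latent frame spans
+        # 160 * 4 / 16000 = 0.04 s of audio, i.e. 25 audio latent frames per second of generated video.
+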
self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear") + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
+ """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
+        batch_size = latents.size(0)
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
+    def _denormalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = latents * latents_std / scaling_factor + latents_mean
+        return latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state
+    def _create_noised_state(
+        latents: torch.Tensor, noise_scale: Union[float, torch.Tensor], generator: Optional[torch.Generator] = None
+    ):
+        noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
+        noised_latents = noise_scale * noise + (1 - noise_scale) * latents
+        return noised_latents
+
+    @staticmethod
+    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents
+    def _pack_audio_latents(
+        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
+    ) -> torch.Tensor:
+        # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins
+        if patch_size is not None and patch_size_t is not None:
+            # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (an ndim=3 tensor).
+            # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size.
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+            latents = latents.reshape(
+                batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+            )
+            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        else:
+            # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel)
+            # patch_size of M (all mel bins constitute a single patch) and a patch_size_t of 1.
+ latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: Optional[int] = None, + patch_size_t: Optional[int] = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. + latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + def prepare_latents( + self, + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + noise_scale: float = 0.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}." 
+ ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + + # First condition is image latents and those should be kept clean. + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # Interpolation. + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_audio_latents(latents)
+        return latents
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: PipelineImageInput = None,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 512,
+        width: int = 768,
+        num_frames: int = 121,
+        frame_rate: float = 24.0,
+        num_inference_steps: int = 40,
+        sigmas: Optional[List[float]] = None,
+        timesteps: List[int] = None,
+        guidance_scale: float = 4.0,
+        guidance_rescale: float = 0.0,
+        noise_scale: float = 0.0,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        audio_latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        decode_timestep: Union[float, List[float]] = 0.0,
+        decode_noise_scale: Optional[Union[float, List[float]]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 1024,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            image (`PipelineImageInput`):
+                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the generated video. This is set to 512 by default for the best results.
+            width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the generated video. This is set to 768 by default for the best results.
+            num_frames (`int`, *optional*, defaults to `121`):
+                The number of video frames to generate.
+            frame_rate (`float`, *optional*, defaults to `24.0`):
+                The frames per second (FPS) of the generated video.
+            num_inference_steps (`int`, *optional*, defaults to 40):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to `4.0`):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. A higher guidance scale encourages the model to generate samples that are
+                closely linked to the text `prompt`, usually at the expense of lower sample quality.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16
+                of [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+                using zero terminal SNR.
+            noise_scale (`float`, *optional*, defaults to `0.0`):
+                The interpolation factor between random noise and denoised latents at each timestep. Noise is applied
+                to the `latents` and `audio_latents` before denoising continues.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            audio_latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated from
+                the `negative_prompt` input argument.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings.
+            decode_timestep (`float`, defaults to `0.0`):
+                The timestep at which generated video is decoded.
+            decode_noise_scale (`float`, defaults to `None`):
+                The interpolation factor between random noise and denoised latents at the decode timestep.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.ltx2.LTX2PipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to `1024`):
+                Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.ltx2.LTX2PipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.ltx2.LTX2PipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated frames.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt=prompt,
+            height=height,
+            width=width,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._attention_kwargs = attention_kwargs
+        self._interrupt = False
+        self._current_timestep = None
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3.
Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." 
+ ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + 
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py new file mode 100644 index 000000000000..340efd10f24f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py @@ -0,0 +1,428 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch + +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLLTX2Video +from ...utils import get_logger, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..ltx.pipeline_output import LTXPipelineOutput +from ..pipeline_utils import DiffusionPipeline +from .latent_upsampler import LTX2LatentUpsamplerModel + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + >>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel + >>> from diffusers.utils import load_image + + >>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="pil", + ... return_dict=False, + ... ) + + >>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained( + ... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16 + ... ) + >>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler) + >>> upsample_pipe.vae.enable_tiling() + >>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16) + + >>> video = upsample_pipe( + ... video=video, + ... width=768, + ... height=512, + ... output_type="np", + ... return_dict=False, + ... 
)[0] + >>> video = (video * 255).round().astype("uint8") + >>> video = torch.from_numpy(video) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTX2LatentUpsamplePipeline(DiffusionPipeline): + model_cpu_offload_seq = "vae->latent_upsampler" + + def __init__( + self, + vae: AutoencoderKLLTX2Video, + latent_upsampler: LTX2LatentUpsamplerModel, + ) -> None: + super().__init__() + + self.register_modules(vae=vae, latent_upsampler=latent_upsampler) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + + def prepare_latents( + self, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_frames: int = 121, + height: int = 512, + width: int = 768, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 3: + # Convert token seq [B, S, D] to latent video [B, C, F, H, W] + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + latents = self._unpack_latents( + latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size + ) + return latents.to(device=device, dtype=dtype) + + video = video.to(device=device, dtype=self.vae.dtype) + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + # NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here + # init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + return init_latents + + def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0): + """ + Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent + tensor. + + Args: + latent (`torch.Tensor`): + Input latents to normalize + reference_latents (`torch.Tensor`): + The reference latents providing style statistics. + factor (`float`): + Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0 + + Returns: + torch.Tensor: The transformed latent tensor + """ + result = latents.clone() + + for i in range(latents.size(0)): + for c in range(latents.size(1)): + r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order + i_sd, i_mean = torch.std_mean(result[i, c], dim=None) + + result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean + + result = torch.lerp(latents, result, factor) + return result + + def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor: + """ + Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually + smooth way using a sigmoid-based compression. + + This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially + when controlling dynamic behavior with a `compression` factor. + + Args: + latents : torch.Tensor + Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range. + compression : float + Compression strength in the range [0, 1]. + - 0.0: No tone-mapping (identity transform) + - 1.0: Full compression effect + + Returns: + torch.Tensor + The tone-mapped latent tensor of the same shape as input. 
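+
+        As a concrete example, with `compression=1.0` (so `scale_factor=0.75`), a latent value with magnitude 1.0 is
+        scaled by `1.0 - 0.8 * 0.75 * sigmoid(0.0) = 0.7`, while magnitudes near 0 keep a scale close to 1.0, so the
+        dynamic range is compressed smoothly rather than clipped.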
+ """ + # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot + scale_factor = compression * 0.75 + abs_latents = torch.abs(latents) + + # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0 + # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect + sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0)) + scales = 1.0 - 0.8 * scale_factor * sigmoid_term + + filtered = latents * scales + return filtered + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + def check_inputs(self, video, height, width, latents, tone_map_compression_ratio): + if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` can be provided.") + if video is None and latents is None: + raise ValueError("One of `video` or `latents` has to be provided.") + + if not (0 <= tone_map_compression_ratio <= 1): + raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: Optional[List[PipelineImageInput]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + spatial_patch_size: int = 1, + temporal_patch_size: int = 1, + latents: Optional[torch.Tensor] = None, + latents_normalized: bool = False, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, + adain_factor: float = 0.0, + tone_map_compression_ratio: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + video (`List[PipelineImageInput]`, *optional*) + The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be + supplied. 
+ height (`int`, *optional*, defaults to `512`): + The height in pixels of the input video (not the generated video, which will have a larger resolution). + width (`int`, *optional*, defaults to `768`): + The width in pixels of the input video (not the generated video, which will have a larger resolution). + num_frames (`int`, *optional*, defaults to `121`): + The number of frames in the input video. + spatial_patch_size (`int`, *optional*, defaults to `1`): + The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary. + temporal_patch_size (`int`, *optional*, defaults to `1`): + The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is + necessary. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a + patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size, + latent_channels, latent_frames, latent_height, latent_width)`. + latents_normalized (`bool`, *optional*, defaults to `False`) + If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If + `True`, the `latents` will be denormalized before being supplied to the latent upsampler. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + adain_factor (`float`, *optional*, defaults to `0.0`): + Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents. + Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed. + tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`): + The compression strength for tone mapping, which will reduce the dynamic range of the latent values. + This is useful for regularizing high-variance latents or for conditioning outputs during generation. + Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to + the full compression effect. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is the upsampled video. + """ + + self.check_inputs( + video=video, + height=height, + width=width, + latents=latents, + tone_map_compression_ratio=tone_map_compression_ratio, + ) + + if video is not None: + # Batched video input is not yet tested/supported. 
TODO: take a look later + batch_size = 1 + else: + batch_size = latents.shape[0] + device = self._execution_device + + if video is not None: + num_frames = len(video) + if num_frames % self.vae_temporal_compression_ratio != 1: + num_frames = ( + num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1 + ) + video = video[:num_frames] + logger.warning( + f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames." + ) + video = self.video_processor.preprocess_video(video, height=height, width=width) + video = video.to(device=device, dtype=torch.float32) + + latents_supplied = latents is not None + latents = self.prepare_latents( + video=video, + batch_size=batch_size, + num_frames=num_frames, + height=height, + width=width, + spatial_patch_size=spatial_patch_size, + temporal_patch_size=temporal_patch_size, + dtype=torch.float32, + device=device, + generator=generator, + latents=latents, + ) + + if latents_supplied and latents_normalized: + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(self.latent_upsampler.dtype) + latents_upsampled = self.latent_upsampler(latents) + + if adain_factor > 0.0: + latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + if tone_map_compression_ratio > 0.0: + latents = self.tone_map_latents(latents, tone_map_compression_ratio) + + if output_type == "latent": + video = latents + else: + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx2/pipeline_output.py b/src/diffusers/pipelines/ltx2/pipeline_output.py new file mode 100644 index 000000000000..eacd571125b0 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTX2PipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ audio (`torch.Tensor`, `np.ndarray`): + TODO + """ + + frames: torch.Tensor + audio: torch.Tensor diff --git a/src/diffusers/pipelines/ltx2/utils.py b/src/diffusers/pipelines/ltx2/utils.py new file mode 100644 index 000000000000..f80469817fe6 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/utils.py @@ -0,0 +1,6 @@ +# Pre-trained sigma values for distilled model are taken from +# https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-pipelines/src/ltx_pipelines/utils/constants.py +DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875] + +# Reduced schedule for super-resolution stage 2 (subset of distilled values) +STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875] diff --git a/src/diffusers/pipelines/ltx2/vocoder.py b/src/diffusers/pipelines/ltx2/vocoder.py new file mode 100644 index 000000000000..217c68103e39 --- /dev/null +++ b/src/diffusers/pipelines/ltx2/vocoder.py @@ -0,0 +1,159 @@ +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: Tuple[int, ...] = (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4], + upsample_factors: List[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: List[int] = [3, 7, 11], + resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." 
+ ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. + + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states diff --git a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py index 69f69d5768a8..8065a17b7889 100644 --- a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py +++ b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py @@ -14,7 +14,7 @@ # limitations under the License. # # Modifications by Decart AI Team: -# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension. +# - Based on pipeline_wan.py, but with supports receiving a condition video appended to the channel dimension. 
import html from typing import Any, Callable, Dict, List, Optional, Tuple, Union diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index b59c265646cd..f4711cf9d9d8 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -799,7 +799,13 @@ def __call__( prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) # 5. Prepare latents. latent_channels = self.transformer.config.in_channels diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 937803edbcbc..8151b29b25fd 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -704,10 +704,14 @@ def __call__( self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("max_shift", 1.15), ) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, mu=mu, ) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 5874a92c6f2f..19a36c73f9ed 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -668,10 +668,14 @@ def __call__( sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) sigmas = np.array(sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, timesteps, sigmas, ) diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 090cb46aace4..96c209813f54 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -459,8 +459,12 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps] + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas=sigmas ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index 1abef014301a..389927aafcbc 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -1131,8 +1131,12 @@ def __call__( assert False # 5. 
Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 381352ccc5d4..8b7df89f039c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -1329,8 +1329,12 @@ def __call__( assert False # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index df5b3f5c10a5..5a6b8d5e9f37 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -85,7 +85,6 @@ >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetPAGImg2ImgPipeline, AutoencoderKL >>> from diffusers.utils import load_image - >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") >>> feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas") >>> controlnet = ControlNetModel.from_pretrained( diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index d156eac8f3f7..6704924b2512 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import BertModel, BertTokenizer, CLIPImageProcessor, MT5Tokenizer, T5EncoderModel +from transformers import BertModel, BertTokenizer, CLIPImageProcessor, T5EncoderModel, T5Tokenizer from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -173,7 +173,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): The HunyuanDiT model designed by Tencent Hunyuan. text_encoder_2 (`T5EncoderModel`): The mT5 embedder. Specifically, it is 't5-v1_1-xxl'. - tokenizer_2 (`MT5Tokenizer`): + tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. scheduler ([`DDPMScheduler`]): A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents. @@ -208,7 +208,7 @@ def __init__( feature_extractor: Optional[CLIPImageProcessor] = None, requires_safety_checker: bool = True, text_encoder_2: Optional[T5EncoderModel] = None, - tokenizer_2: Optional[MT5Tokenizer] = None, + tokenizer_2: Optional[T5Tokenizer] = None, pag_applied_layers: Union[str, List[str]] = "blocks.1", # "blocks.16.attn1", "blocks.16", "16", 16 ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py index 1403be03a620..5b82d546445b 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -905,8 +905,12 @@ def __call__( ) # 4. 
Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index 9031877b5b8d..283862989c71 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -764,8 +764,12 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latents. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 9e91ccbe8006..466996889417 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -856,8 +856,12 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latents. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py index ea64f8be2c50..67676fb28798 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -952,8 +952,12 @@ def __call__( ip_adapter_image_embeds[i] = image_embeds # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index 941b675099b9..303a0a2f0b2e 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -888,7 +888,13 @@ def __call__( pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 4. 
Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index f40dd52fc244..2005c865c22b 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -951,7 +951,13 @@ def __call__( image = self.image_processor.preprocess(image, height=height, width=width) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py index 8351112ce409..42b5db0fa762 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py @@ -986,8 +986,12 @@ def __call__( image = self.image_processor.preprocess(image) # 5. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py index 6b1b294e10f5..cf8c4972762f 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py @@ -1094,8 +1094,12 @@ def __call__( ) # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index a69f06536a55..0613ec23f740 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -1095,8 +1095,12 @@ def __call__( ) # 4. 
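# A minimal sketch of the recurring change in the PAG pipeline hunks above: when
# torch_xla is importable, the scheduler timesteps are prepared on the CPU instead of
# the pipeline's execution device, so `retrieve_timesteps()` avoids extra host/device
# transfers under XLA. Only `is_torch_xla_available` is the real diffusers helper;
# the small function below is an illustrative stand-in, not pipeline code.
from diffusers.utils import is_torch_xla_available

XLA_AVAILABLE = is_torch_xla_available()


def timestep_device_for(execution_device):
    # Keep timesteps on the host when running under XLA; otherwise follow the pipeline device.
    return "cpu" if XLA_AVAILABLE else execution_device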
Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 416d9e5677b4..1081993f46e6 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -1272,8 +1272,12 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 6be341e07b1a..f6c4982c1c6c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -1392,8 +1392,12 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 8868e942ce3d..57d4eaa8f89e 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -758,6 +758,7 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + disable_mmap: bool, quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" @@ -801,12 +802,6 @@ def load_sub_model( # add kwargs to loading method diffusers_module = importlib.import_module(__name__.split(".")[0]) loading_kwargs = {} - if issubclass(class_obj, torch.nn.Module): - loading_kwargs["torch_dtype"] = torch_dtype - if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): - loading_kwargs["provider"] = provider - loading_kwargs["sess_options"] = sess_options - loading_kwargs["provider_options"] = provider_options is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin) @@ -821,6 +816,17 @@ def load_sub_model( and transformers_version >= version.parse("4.20.0") ) + # For transformers models >= 4.56.0, use 'dtype' instead of 'torch_dtype' to avoid deprecation warnings + if issubclass(class_obj, torch.nn.Module): + if is_transformers_model and transformers_version >= version.parse("4.56.0"): + loading_kwargs["dtype"] = torch_dtype + else: + loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers_module.OnnxRuntimeModel): + loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options + loading_kwargs["provider_options"] = provider_options 
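# A hedged sketch of the kwarg selection load_sub_model() now performs above: newer
# transformers releases (>= 4.56.0) expect `dtype` rather than the deprecated
# `torch_dtype`, while older releases and plain torch.nn.Module subclasses still take
# `torch_dtype`. The helper name below is hypothetical and only illustrates the gate.
import torch
import transformers
from packaging import version


def dtype_loading_kwargs(is_transformers_model: bool, torch_dtype: torch.dtype) -> dict:
    transformers_version = version.parse(transformers.__version__)
    if is_transformers_model and transformers_version >= version.parse("4.56.0"):
        return {"dtype": torch_dtype}
    return {"torch_dtype": torch_dtype}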
+ # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers. # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default. # This makes sure that the weights won't be initialized which significantly speeds up loading. @@ -854,6 +860,9 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if is_diffusers_model: + loading_kwargs["disable_mmap"] = disable_mmap + if is_transformers_model and is_transformers_version(">=", "4.57.0"): loading_kwargs.pop("offload_state_dict") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 392d5fb3feb4..b96305c74131 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -60,6 +60,7 @@ deprecate, is_accelerate_available, is_accelerate_version, + is_bitsandbytes_version, is_hpu_available, is_torch_npu_available, is_torch_version, @@ -67,6 +68,7 @@ logging, numpy_to_pil, ) +from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module @@ -443,7 +445,10 @@ def module_is_sequentially_offloaded(module): _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module) - if is_loaded_in_8bit_bnb: + # https://github.com/huggingface/accelerate/pull/3907 + if is_loaded_in_8bit_bnb and ( + is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0") + ): return False return hasattr(module, "_hf_hook") and ( @@ -522,9 +527,10 @@ def module_is_offloaded(module): f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." ) - if is_loaded_in_8bit_bnb and device is not None: + if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"): logger.warning( f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." + "You need to upgrade bitsandbytes to at least 0.48.0" ) # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling @@ -541,6 +547,14 @@ def module_is_offloaded(module): # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) + # added here https://github.com/huggingface/transformers/pull/43258 + if ( + is_loaded_in_8bit_bnb + and device is not None + and is_transformers_version(">", "4.58.0") + and is_bitsandbytes_version(">=", "0.48.0") + ): + module.to(device=device) elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded: module.to(device, dtype) @@ -707,6 +721,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loading `from_flax`. dduf_file(`str`, *optional*): Load weights from the specified dduf file. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. 
This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf > auth login`. @@ -758,6 +775,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) + disable_mmap = kwargs.pop("disable_mmap", False) if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -982,7 +1000,11 @@ def load_module(name, value): # 7. Load each module in the pipeline current_device_map = None _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config) - for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): + logging_tqdm_kwargs = {"desc": "Loading pipeline components..."} + if not is_torch_dist_rank_zero(): + logging_tqdm_kwargs["disable"] = True + + for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs): # 7.1 device_map shenanigans if final_device_map is not None: if isinstance(final_device_map, dict) and len(final_device_map) > 0: @@ -1041,6 +1063,7 @@ def load_module(name, value): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, + disable_mmap=disable_mmap, quantization_config=quantization_config, ) logger.info( @@ -1218,7 +1241,9 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t # This is because the model would already be placed on a CUDA device. _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model) - if is_loaded_in_8bit_bnb: + if is_loaded_in_8bit_bnb and ( + is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0") + ): logger.info( f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." ) @@ -1908,10 +1933,14 @@ def progress_bar(self, iterable=None, total=None): f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) + progress_bar_config = dict(self._progress_bar_config) + if "disable" not in progress_bar_config: + progress_bar_config["disable"] = not is_torch_dist_rank_zero() + if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) + return tqdm(iterable, **progress_bar_config) elif total is not None: - return tqdm(total=total, **self._progress_bar_config) + return tqdm(total=total, **progress_bar_config) else: raise ValueError("Either `total` or `iterable` has to be defined.") diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 1d718a4852a4..2ecc13ef71bf 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -862,8 +862,12 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. 
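# A short usage sketch for the `disable_mmap` flag that from_pretrained() now threads
# through to load_sub_model() above. The checkpoint id is only an example; any
# diffusers checkpoint stored as safetensors behaves the same way.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # example checkpoint, not prescribed by the patch
    torch_dtype=torch.float16,
    disable_mmap=True,  # skip mmap when reading safetensors; can help on network mounts or HDDs
)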
Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latents. diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index bb169ac5c443..f53d3c5630f0 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -806,8 +806,12 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latents. diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py index 2400632ba2bd..3f43d0ebb0b9 100644 --- a/src/diffusers/pipelines/qwenimage/__init__.py +++ b/src/diffusers/pipelines/qwenimage/__init__.py @@ -31,6 +31,7 @@ _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"] _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"] _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"] + _import_structure["pipeline_qwenimage_layered"] = ["QwenImageLayeredPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -47,6 +48,7 @@ from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline + from .pipeline_qwenimage_layered import QwenImageLayeredPipeline else: import sys diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 33dc2039b986..9938ae95fd55 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -254,13 +254,17 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None return prompt_embeds, prompt_embeds_mask @@ -307,15 +311,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) - if prompt_embeds is not None and prompt_embeds_mask is None: - raise ValueError( - "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") @@ -672,11 +667,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -695,7 +685,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -709,7 +698,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 5111096d93c1..e0cb9924cda3 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -321,8 +321,13 @@ def encode_prompt( _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None return prompt_embeds, prompt_embeds_mask @@ -369,15 +374,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if prompt_embeds is not None and prompt_embeds_mask is None: - raise ValueError( - "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." 
- ) - if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") @@ -909,7 +905,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -920,7 +915,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -935,7 +929,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 102a813ab582..f8521318e630 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -305,6 +305,9 @@ def encode_prompt( prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( @@ -852,7 +855,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), return_dict=False, ) @@ -863,7 +865,6 @@ def __call__( encoder_hidden_states=prompt_embeds, encoder_hidden_states_mask=prompt_embeds_mask, img_shapes=img_shapes, - txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -878,7 +879,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(), controlnet_block_samples=controlnet_block_samples, attention_kwargs=self.attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index ed37b238c8c9..353aadcbf08a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -309,6 +309,9 @@ def encode_prompt( prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( @@ -793,11 +796,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - 
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -821,7 +819,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -836,7 +833,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index d54d1881fa4e..bc397f357bf5 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -321,6 +321,9 @@ def encode_prompt( prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs @@ -375,14 +378,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if prompt_embeds is not None and prompt_embeds_mask is None: - raise ValueError( - "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( @@ -1008,11 +1003,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
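# A minimal sketch of the mask normalization the QwenImage hunks above introduce: after
# truncating and repeating the prompt mask, a mask that is all ones carries no padding
# information, so it is replaced with None and the transformer call no longer needs the
# derived `txt_seq_lens` list. Shapes are illustrative only.
import torch

prompt_embeds_mask = torch.ones(2, 128, dtype=torch.long)  # e.g. both prompts use the full sequence
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
    prompt_embeds_mask = None  # a fully-visible mask means attention masking can be skipped entirely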
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1035,7 +1025,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -1050,7 +1039,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..bc688aeee319 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -323,6 +323,9 @@ def encode_prompt( prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs @@ -663,6 +666,13 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + # QwenImageEditPlusPipeline does not currently support batch_size > 1 + if batch_size > 1: + raise ValueError( + f"QwenImageEditPlusPipeline currently only supports batch_size=1, but received batch_size={batch_size}. " + "Please process prompts one at a time." + ) + device = self._execution_device # 3. Preprocess image if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): @@ -777,11 +787,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop self.scheduler.set_begin_index(0) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -805,7 +810,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -820,7 +824,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index cb4c5d8016bb..522e1d203051 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -265,7 +265,7 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -297,13 +297,17 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None return prompt_embeds, prompt_embeds_mask @@ -354,15 +358,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if prompt_embeds is not None and prompt_embeds_mask is None: - raise ValueError( - "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) - if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") @@ -775,11 +770,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -797,7 +787,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -811,7 +800,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 1915c27eb2bb..960196bec166 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -276,7 +276,7 @@ def get_timesteps(self, num_inference_steps, strength, device): return timesteps, num_inference_steps - t_start - # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -308,13 +308,17 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None return prompt_embeds, prompt_embeds_mask @@ -369,14 +373,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if prompt_embeds is not None and prompt_embeds_mask is None: - raise ValueError( - "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." - ) - if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: - raise ValueError( - "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." - ) if padding_mask_crop is not None: if not isinstance(image, PIL.Image.Image): raise ValueError( @@ -944,11 +940,6 @@ def __call__( if self.attention_kwargs is None: self._attention_kwargs = {} - txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None - negative_txt_seq_lens = ( - negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None - ) - # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -966,7 +957,6 @@ def __call__( encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] @@ -980,7 +970,6 @@ def __call__( encoder_hidden_states_mask=negative_prompt_embeds_mask, encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, - txt_seq_lens=negative_txt_seq_lens, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py new file mode 100644 index 000000000000..ea455721e4e5 --- /dev/null +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -0,0 +1,903 @@ +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import QwenImageLoraLoaderMixin +from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import QwenImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageLayeredPipeline + >>> from diffusers.utils import load_image + + >>> pipe = QwenImageLayeredPipeline.from_pretrained("Qwen/Qwen-Image-Layered", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png" + ... ).convert("RGBA") + >>> prompt = "" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> images = pipe( + ... image, + ... prompt, + ... num_inference_steps=50, + ... true_cfg_scale=4.0, + ... layers=4, + ... resolution=640, + ... cfg_normalize=False, + ... use_en_prompt=True, + ... ).images[0] + >>> for i, image in enumerate(images): + ... 
image.save(f"{i}.out.png") + ``` +""" + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus.calculate_dimensions +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +class QwenImageLayeredPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin): + r""" + The Qwen-Image-Layered pipeline for image decomposing. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
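# A small worked example for calculate_dimensions() above (input values are
# illustrative): the target pixel area is split according to the aspect ratio and both
# sides are then rounded to the nearest multiple of 32.
import math


def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    width = round(width / 32) * 32
    height = round(height / 32) * 32
    return width, height


print(calculate_dimensions(1024 * 1024, 16 / 9))  # (1376, 768)
print(calculate_dimensions(640 * 640, 1.0))  # (640, 640)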
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.vl_processor = processor + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.image_caption_prompt_cn = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# 图像标注器\n你是一个专业的图像标注器。请基于输入图像,撰写图注:\n1. +使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n2. 通过加入以下内容,丰富图注细节:\n - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n - +对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n - 环境细节:例如天气、光照、颜色、纹理、气氛等\n - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n3. +保持真实性与准确性:\n - 不要使用笼统的描述\n - +描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + self.image_caption_prompt_en = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n# Image Annotator\nYou are a professional +image annotator. Please write an image caption based on the input image:\n1. Write the caption using natural, +descriptive language without structured formats or rich text.\n2. Enrich caption details by including: \n - Object +attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n - Vision Relations +between objects, such as spatial relations, functional relations, possessive relations, attachment relations, action +relations, comparative relations, causal relations, and so on\n - Environmental details, such as weather, lighting, +colors, textures, atmosphere, and so on\n - Identify the text clearly visible in the image, without translation or +explanation, and highlight it in the caption with quotation marks\n3. 
Maintain authenticity and accuracy:\n - Avoid +generalizations\n - Describe all visible information in the image, while do not add information not explicitly shown in +the image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n""" + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, + padding=True, + return_tensors="pt", + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
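# A tiny sketch of what _extract_masked_hidden() above computes: it keeps only the
# positions where the attention mask is 1 and returns one variable-length tensor per
# prompt in the batch. Shapes are illustrative.
import torch

hidden_states = torch.randn(2, 6, 4)  # (batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]])
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)  # tensor([3, 5])
selected = hidden_states[bool_mask]  # (8, 4): all kept tokens stacked together
per_prompt = torch.split(selected, valid_lengths.tolist(), dim=0)  # tensors of shape (3, 4) and (5, 4)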
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + if prompt_embeds_mask.all(): + prompt_embeds_mask = None + + return prompt_embeds, prompt_embeds_mask + + def get_image_caption(self, prompt_image, use_en_prompt=True, device=None): + if use_en_prompt: + prompt = self.image_caption_prompt_en + else: + prompt = self.image_caption_prompt_cn + model_inputs = self.vl_processor( + text=prompt, + images=prompt_image, + padding=True, + return_tensors="pt", + ).to(device) + generated_ids = self.text_encoder.generate(**model_inputs, max_new_tokens=512) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids) + ] + output_text = self.vl_processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return output_text.strip() + + def check_inputs( + self, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width, layers): + latents = latents.view(batch_size, layers, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4, 6) + latents = latents.reshape(batch_size, layers * (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, layers, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, layers + 1, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 1, 4, 2, 5, 3, 6) + + latents = latents.reshape(batch_size, layers + 1, channels // (2 * 2), height, width) + latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + image, + batch_size, + num_channels_latents, + height, + width, + layers, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = ( + batch_size, + layers + 1, + num_channels_latents, + height, + width, + ) ### the generated first image is combined image + + image_latents = None + if image is not None: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
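# A hedged sketch of the layered packing done by _pack_latents() above: each layer's
# latent grid is cut into 2x2 patches, so the token sequence spans every layer (the
# first "layer" slot holds the combined image, hence `layers + 1` in prepare_latents)
# and the channel dimension grows by a factor of 4. Sizes below are illustrative.
import torch

b, layers, c, h, w = 1, 5, 16, 64, 64  # here `layers` already includes the combined frame
latents = torch.randn(b, layers, c, h, w)
packed = (
    latents.view(b, layers, c, h // 2, 2, w // 2, 2)
    .permute(0, 1, 3, 5, 2, 4, 6)
    .reshape(b, layers * (h // 2) * (w // 2), c * 4)
)
print(packed.shape)  # torch.Size([1, 5120, 64]) -> 5 layers x 32 x 32 patches, 16 * 4 channels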
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = image_latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w) -> (b, f, c, h, w) + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width, 1 + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, layers + 1) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[PipelineImageInput] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + layers: Optional[int] = 4, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + resolution: int = 640, + cfg_normalize: bool = False, + use_en_prompt: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free
+                Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+                equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+                enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
+                encourages generating images that are closely linked to the text `prompt`, usually at the expense of
+                lower image quality.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to None):
+                A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+                where the guidance scale is applied during inference through noise prediction rescaling, guidance
+                distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+                scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages generating images
+                that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+                parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+                ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+                please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+                enable classifier-free guidance computations).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+            resolution (`int`, *optional*, defaults to 640):
+                The resolution bucket, either 640 or 1024, used to determine the condition and output resolution.
+            cfg_normalize (`bool`, *optional*, defaults to `False`):
+                Whether to enable classifier-free guidance normalization.
+            use_en_prompt (`bool`, *optional*, defaults to `False`):
+                Whether the automatically generated image caption (used when `prompt` is not provided) should be in
+                English rather than Chinese.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+                [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+                returning a tuple, the first element is a list with the generated images.
+        """
+        image_size = image[0].size if isinstance(image, list) else image.size
+        assert resolution in [640, 1024], f"resolution must be either 640 or 1024, but got {resolution}"
+        calculated_width, calculated_height = calculate_dimensions(
+            resolution * resolution, image_size[0] / image_size[1]
+        )
+        height = calculated_height
+        width = calculated_width
+
+        multiple_of = self.vae_scale_factor * 2
+        width = width // multiple_of * multiple_of
+        height = height // multiple_of * multiple_of
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        device = self._execution_device
+        # 2.
Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + image = image.to(dtype=self.text_encoder.dtype) + + if prompt is None or prompt == "" or prompt == " ": + prompt = self.get_image_caption(prompt_image, use_en_prompt=use_en_prompt, device=device) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + layers, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + *[ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2) + for _ in range(layers + 1) + ], + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + image_seq_len = latents.shape[1] + base_seqlen = 256 * 256 / 16 / 16 + mu = (image_latents.shape[1] / base_seqlen) ** 0.5 + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + is_rgb = torch.tensor([0] * batch_size).to(device=device, dtype=torch.long) + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + additional_t_cond=is_rgb, + return_dict=False, + )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + additional_t_cond=is_rgb, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + if cfg_normalize: + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + noise_pred = comb_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, layers, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + + b, c, f, h, w = latents.shape + + latents = latents[:, :, 1:] # remove the first frame as it is the orgin input + + latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w) + + image = self.vae.decode(latents, return_dict=False)[0] # (b f) c 1 h w + + image = image.squeeze(2) + + image = self.image_processor.postprocess(image, output_type=output_type) + images = [] + for bidx in range(b): + images.append(image[bidx * f : (bidx + 1) * f]) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + + return QwenImagePipelineOutput(images=images) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 2beff802c6e0..33f9de7d20f0 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -927,8 +927,12 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latents. diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index 55ed7b84ebdf..9d5e17c2ed48 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -1010,8 +1010,12 @@ def __call__( raise ValueError("`controlnet` must be of type `SanaControlNetModel`.") # 5. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 6. Prepare latents. 
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 04f45f817efb..4c6d2247495d 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -790,10 +790,14 @@ def __call__( ) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, timesteps, sigmas=None, max_timesteps=max_timesteps, diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index d6cd7d7feceb..1b1c8ee097c5 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -545,22 +545,24 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 089f92632d38..4bc0d0aaea83 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -887,25 +887,28 @@ def __call__( ) timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) 
update_mask_i = step_update_mask[i] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 2951a9447386..3e2004533258 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -966,25 +966,28 @@ def __call__( ) timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) update_mask_i = step_update_mask[i] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 6fedfc795a40..234ec531b862 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -974,25 +974,28 @@ def __call__( ) timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + enable_diffusion_forcing=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) update_mask_i = step_update_mask[i] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index d61b687eadc3..d1df7f5f34cb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ 
b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -678,24 +678,26 @@ def __call__( latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=image_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index cb97f18efeff..d079d2a225cf 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -989,8 +989,11 @@ def __call__( ) # 4. Prepare timesteps + timestep_device = device + if XLA_AVAILABLE: + timestep_device = "cpu" timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latent variables @@ -1093,6 +1096,8 @@ def __call__( do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + if XLA_AVAILABLE: + xm.mark_step() image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 95d3ab06f02a..d0be0ee51317 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1050,8 +1050,12 @@ def __call__( image = self.image_processor.preprocess(image) # 5. 
set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 148d7386a732..82902cc7dcd0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1136,8 +1136,12 @@ def __call__( ) # 4. set timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 66d5ffa6b849..a1d0407caf5e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -459,7 +459,6 @@ def __call__( >>> from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline >>> import torch - >>> pipeline = StableDiffusionPipeline.from_pretrained( ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16 ... 
) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 16aff102599c..65daafe01237 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -17,7 +17,7 @@ import torch.nn as nn from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel -from ...utils import logging +from ...utils import is_transformers_version, logging logger = logging.get_logger(__name__) @@ -46,6 +46,9 @@ def __init__(self, config: CLIPConfig): self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + # Model requires post_init after transformers v4.57.3 + if is_transformers_version(">", "4.57.3"): + self.post_init() @torch.no_grad() def forward(self, clip_input, images): diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 660d9801df56..fcd108aef4c2 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1025,10 +1025,14 @@ def __call__( scheduler_kwargs["mu"] = mu elif mu is not None: scheduler_kwargs["mu"] = mu + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, - device, + timestep_device, sigmas=sigmas, **scheduler_kwargs, ) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 9b11bc8781e7..e6ddbb5544c7 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -1047,8 +1047,12 @@ def __call__( scheduler_kwargs["mu"] = mu elif mu is not None: scheduler_kwargs["mu"] = mu + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas, **scheduler_kwargs ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index b947cbff0914..b1b30efc7da3 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -1167,8 +1167,13 @@ def __call__( scheduler_kwargs["mu"] = mu elif mu is not None: scheduler_kwargs["mu"] = mu + + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas, **scheduler_kwargs ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # check that 
number of inference steps is not < 1 - as this doesn't make sense diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index 295095947a12..6d93e5feab4d 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -1000,7 +1000,13 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e969d2a21a99..3a63bb4f253a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1094,8 +1094,12 @@ def __call__( ) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 8d1da8dc102c..d1916b635f92 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1264,8 +1264,12 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 54a1e311804c..fcfddc192b8b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1399,8 +1399,12 @@ def __call__( def denoising_value_valid(dnv): return isinstance(dnv, float) and 0 < dnv < 1 + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, diff --git 
a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 6d9053faaec8..633094239dca 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -544,7 +544,13 @@ def __call__( added_time_ids = added_time_ids.to(device) # 6. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, None, sigmas + ) # 7. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 1ce6987114a7..7b6673cf16f7 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -848,8 +848,12 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 0ea3ba5046cf..bf089bf540ba 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -1130,8 +1130,12 @@ def __call__( ) # 4. Prepare timesteps + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas + self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas ) # 5. 
Prepare latent variables diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 78fe71ea9138..e77a3356b883 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -76,7 +76,8 @@ def basic_clean(text): - text = ftfy.fix_text(text) + if is_ftfy_available(): + text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() @@ -495,6 +496,22 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + patch_size = ( + self.transformer.config.patch_size + if self.transformer is not None + else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. " + f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + if self.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index b7fd0b05980f..c1c4a92c2c39 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -637,6 +637,22 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + patch_size = ( + self.transformer.config.patch_size + if self.transformer is not None + else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper patchification. " + f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + if self.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index a976126da7fe..5475b6e8b479 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -622,7 +622,13 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. 
Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if XLA_AVAILABLE: + timestep_device = "cpu" + else: + timestep_device = device + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, timestep_device, timesteps + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index f95b3e5a0bed..78bd3bfacbec 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -23,6 +23,10 @@ else: _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] _import_structure["pipeline_z_image"] = ["ZImagePipeline"] + _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] + _import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"] + _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] + _import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,7 +39,10 @@ else: from .pipeline_output import ZImagePipelineOutput from .pipeline_z_image import ZImagePipeline - + from .pipeline_z_image_controlnet import ZImageControlNetPipeline + from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline + from .pipeline_z_image_img2img import ZImageImg2ImgPipeline + from .pipeline_z_image_omni import ZImageOmniPipeline else: import sys diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py new file mode 100644 index 000000000000..08fc4da0e7ba --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -0,0 +1,724 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageControlNetPipeline + >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import load_image + >>> from huggingface_hub import hf_hub_download + + >>> controlnet = ZImageControlNetModel.from_single_file( + ... hf_hub_download( + ... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union", + ... filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", + ... ), + ... torch_dtype=torch.bfloat16, + ... ) + + >>> # 2.1 + >>> # controlnet = ZImageControlNetModel.from_single_file( + >>> # hf_hub_download( + >>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + >>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", + >>> # ), + >>> # torch_dtype=torch.bfloat16, + >>> # ) + + >>> # 2.0 + >>> # controlnet = ZImageControlNetModel.from_single_file( + >>> # hf_hub_download( + >>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + >>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors", + >>> # ), + >>> # torch_dtype=torch.bfloat16, + >>> # ) + + >>> pipe = ZImageControlNetPipeline.from_pretrained( + ... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> control_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true" + ... ) + >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。" + >>> image = pipe( + ... prompt, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.75, + ... height=1728, + ... width=992, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(43), + ... 
).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + 
embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.75, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. 
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            cfg_normalization (`bool`, *optional*, defaults to False):
+                Whether to renormalize the classifier-free guidance prediction so that its norm does not exceed the
+                norm of the prompt-conditioned prediction.
+            cfg_truncation (`float`, *optional*, defaults to 1.0):
+                Normalized-time threshold above which classifier-free guidance is disabled. With the default of 1.0,
+                guidance is applied at every denoising step.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain
+                tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`.
`callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + if num_channels_latents != self.controlnet.config.control_in_dim: + # For model version 2.0 + control_image = torch.cat( + [ + control_image, + torch.zeros( + control_image.shape[0], + self.controlnet.config.control_in_dim - num_channels_latents, + *control_image.shape[2:], + ).to(device=control_image.device, dtype=control_image.dtype), + ], + dim=1, + ) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_image, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py new file mode 100644 index 000000000000..3b0f8dc288d3 --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -0,0 +1,746 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageControlNetInpaintPipeline + >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import load_image + >>> from huggingface_hub import hf_hub_download + + >>> controlnet = ZImageControlNetModel.from_single_file( + ... hf_hub_download( + ... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + ... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors", + ... ), + ... torch_dtype=torch.bfloat16, + ... ) + + >>> # 2.0 + >>> # controlnet = ZImageControlNetModel.from_single_file( + >>> # hf_hub_download( + >>> # "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + >>> # filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors", + >>> # ), + >>> # torch_dtype=torch.bfloat16, + >>> # ) + + >>> pipe = ZImageControlNetInpaintPipeline.from_pretrained( + ... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/inpaint.jpg?download=true" + ... ) + >>> mask_image = load_image( + ... 
"https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/mask.jpg?download=true" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/pose.jpg?download=true" + ... ) + >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,画面为全身竖构图,身体微微侧向右侧,左手自然下垂,右臂弯曲扶在腰间,她的手指清晰可见,站姿放松而略带羞涩。她身穿轻盈的白色连衣裙,裙摆在海风中轻轻飘动,布料半透、质感柔软。女子拥有一头鲜艳的及腰紫色长发,被海风吹起,在身侧轻盈飞舞,发间系着一个精致的黑色蝴蝶结,与发色形成对比。她面容清秀,眉目精致,肤色白皙细腻,表情温柔略显羞涩,微微低头,眼神静静望向远处的海平线,流露出甜美的青春气息与若有所思的神情。背景是辽阔无垠的海洋与蔚蓝天空,阳光从侧前方洒下,海面波光粼粼,泛着温暖的金色光晕,天空清澈明亮,云朵稀薄,整体色调清新唯美。" + >>> image = pipe( + ... prompt, + ... image=image, + ... mask_image=mask_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.75, + ... height=1728, + ... width=992, + ... num_inference_steps=25, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(43), + ... ).images[0] + >>> image.save("zimage-inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + if transformer.in_channels == controlnet.config.control_in_dim: + raise ValueError( + "ZImageControlNetInpaintPipeline is not compatible with `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union`, use `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0`." 
+ ) + controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + 
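The size arithmetic in `prepare_latents` above is what drives the divisibility checks in `__call__`: both `height` and `width` must be multiples of `vae_scale_factor * 2`. A minimal, illustrative sketch of that math (not part of the pipeline), assuming the default 8x VAE and the 1728x992 resolution used in the inpaint example docstring:

```py
# Illustrative sketch only: mirrors the size math in `prepare_latents` and `__call__`,
# assuming the default vae_scale_factor of 8 (so the effective patch size is 8 * 2 = 16).
vae_scale_factor = 8
height, width = 1728, 992  # resolution used in the inpaint example above

# `__call__` rejects sizes that are not multiples of vae_scale_factor * 2 (= 16 here)
assert height % (vae_scale_factor * 2) == 0 and width % (vae_scale_factor * 2) == 0

latent_height = 2 * (height // (vae_scale_factor * 2))  # 216
latent_width = 2 * (width // (vae_scale_factor * 2))  # 124
image_seq_len = (latent_height // 2) * (latent_width // 2)  # 108 * 62 = 6696
print(latent_height, latent_width, image_seq_len)
```

The `image_seq_len` computed this way is what `__call__` later feeds into `calculate_shift` to choose the flow-matching shift `mu` for the scheduler.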
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.75, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            cfg_normalization (`bool`, *optional*, defaults to False):
+                Whether to renormalize the classifier-free guidance prediction so that its norm does not exceed the
+                norm of the prompt-conditioned prediction.
+            cfg_truncation (`float`, *optional*, defaults to 1.0):
+                Normalized-time threshold above which classifier-free guidance is disabled. With the default of 1.0,
+                guidance is applied at every denoising step.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain
+                tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512):
+                Maximum sequence length to use with the `prompt`.
+ + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to( + device=control_image.device, dtype=control_image.dtype + ) + + init_image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = init_image.shape[-2:] + init_image = init_image * (mask_condition < 0.5) + init_image = retrieve_latents(self.vae.encode(init_image), generator=generator, sample_mode="argmax") + init_image = (init_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + init_image = init_image.unsqueeze(2) + + mask_condition = F.interpolate(1 - mask_condition[:, :1], size=init_image.size()[-2:], mode="nearest").to( + device=control_image.device, dtype=control_image.dtype + ) + mask_condition = mask_condition.unsqueeze(2) + + control_image = torch.cat([control_image, mask_condition, init_image], dim=1) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_image, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py new file mode 100644 index 000000000000..2b3e80a2082b --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py @@ -0,0 +1,709 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = ZImageImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> init_image = load_image(url).resize((1024, 1024)) + + >>> prompt = "A fantasy landscape with mountains and a river, detailed, vibrant colors" + >>> image = pipe( + ... prompt, + ... image=init_image, + ... strength=0.6, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... 
).images[0] + >>> image.save("zimage_img2img.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + r""" + The ZImage pipeline for image-to-image generation. + + Args: + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`PreTrainedModel`]): + A text encoder model to encode text prompts. + tokenizer ([`AutoTokenizer`]): + A tokenizer to tokenize text prompts. + transformer ([`ZImageTransformer2DModel`]): + A ZImage transformer model to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": 
"user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + # Encode the input image + image = image.to(device=device, dtype=dtype) + if image.shape[1] != num_channels_latents: + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + + # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor) + image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + image_latents = image + + # Handle batch size expansion + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + + # Add noise using flow matching scale_noise + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + strength: float = 0.6, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for image-to-image generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a + list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or + a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. + strength (`float`, *optional*, defaults to 0.6): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. If not provided, uses the input image height. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. If not provided, uses the input image width. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            cfg_normalization (`bool`, *optional*, defaults to False):
+                Whether to renormalize the classifier-free guidance prediction so that its norm does not exceed the
+                norm of the prompt-conditioned prediction.
+            cfg_truncation (`float`, *optional*, defaults to 1.0):
+                Normalized-time threshold above which classifier-free guidance is disabled. With the default of 1.0,
+                guidance is applied at every denoising step.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain
+                tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + # 1. Check inputs and validate strength + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") + + # 2. Preprocess image + init_image = self.image_processor.preprocess(image) + init_image = init_image.to(dtype=torch.float32) + + # Get dimensions from the preprocessed image if not specified + if height is None: + height = init_image.shape[-2] + if width is None: + width = init_image.shape[-1] + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + + # Calculate latent dimensions for image_seq_len + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + image_seq_len = (latent_height // 2) * (latent_width // 2) + + # 5. 
Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + + # 6. Adjust timesteps based on strength + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(actual_batch_size) + + # 7. Prepare latents from image + latents = self.prepare_latents( + init_image, + latent_timestep, + actual_batch_size, + num_channels_latents, + height, + width, + prompt_embeds[0].dtype, + device, + generator, + latents, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 8. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) 
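The guidance branch above combines two ideas: truncation (guidance is dropped once the normalized time exceeds `cfg_truncation`) and an optional rescaling that caps the norm of the guided prediction. The sketch below restates only that arithmetic with toy tensors; the function name and shapes are illustrative, not part of the pipeline.

```py
import torch

def sketch_guided_prediction(pos, neg, guidance_scale, cfg_normalization=0.0):
    # Classifier-free guidance: push the conditional prediction away from the
    # unconditional one.
    pred = pos + guidance_scale * (pos - neg)
    # Optional renormalization: cap the norm of the guided prediction at
    # `cfg_normalization` times the norm of the conditional prediction
    # (the pipeline treats the flag as a numeric factor, so True acts as 1.0).
    if cfg_normalization and float(cfg_normalization) > 0.0:
        max_norm = torch.linalg.vector_norm(pos) * float(cfg_normalization)
        new_norm = torch.linalg.vector_norm(pred)
        if new_norm > max_norm:
            pred = pred * (max_norm / new_norm)
    return pred

# Example with toy tensors (shapes are arbitrary):
pos, neg = torch.randn(16, 8, 8), torch.randn(16, 8, 8)
out = sketch_guided_prediction(pos, neg, guidance_scale=5.0, cfg_normalization=1.0)
```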
+ noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py new file mode 100644 index 000000000000..26848bea0a9e --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_omni.py @@ -0,0 +1,742 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +from transformers import AutoTokenizer, PreTrainedModel, Siglip2ImageProcessorFast, Siglip2VisionModel + +from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin +from ...models.autoencoders import AutoencoderKL +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..flux2.image_processor import Flux2ImageProcessor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageOmniPipeline + + >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. 
+ >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + siglip: Siglip2VisionModel, + siglip_processor: Siglip2ImageProcessorFast, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] 
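To make the templating above concrete, this small sketch rebuilds the segment list that `_encode_prompt` constructs before tokenization; the helper is illustrative only, while the special tokens are taken verbatim from the code above.

```py
def sketch_prompt_segments(prompt_item: str, num_condition_images: int):
    # Mirrors the segment construction in `_encode_prompt` above.
    if num_condition_images == 0:
        return ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"]
    segments = ["<|im_start|>user\n<|vision_start|>"]
    segments += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1)
    segments += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"]
    segments += ["<|vision_end|><|im_end|>"]
    return segments

# Two condition images -> four text segments; each segment is tokenized and encoded
# separately, and its non-padded hidden states form one entry of that prompt's
# embedding list.
print(sketch_prompt_segments("a cat wearing a hat", num_condition_images=2))
```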
+ + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = self.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = ( + self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + # image_latents = [image_latents] * batch_size + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + def prepare_siglip_embeds( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = self.siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + 
cfg_truncation: float = 1.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if latents are passed directly they are not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to 1024):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 1024):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of the [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ cfg_normalization (`bool`, *optional*, defaults to False):
+ Whether to rescale the classifier-free guidance prediction so that its norm does not exceed the norm of
+ the conditional (positive) prediction.
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
+ The fraction of the denoising trajectory (normalized to `[0, 1]`) during which classifier-free guidance
+ is applied; guidance is disabled for steps beyond this threshold.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
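Because the example docstring earlier in this file only demonstrates text-to-image use, here is a hedged sketch of an image-conditioned call. The checkpoint id is copied from that docstring and may not ship the `siglip` components this pipeline requires, and the reference image URL is a placeholder.

```py
import torch
from diffusers import ZImageOmniPipeline
from diffusers.utils import load_image

pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# One or more condition images are passed as a list of PIL images; they are encoded
# by the VAE and SigLIP and injected as clean (noise-mask 0) tokens.
ref = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    image=[ref],
    prompt="restyle the reference photo as a watercolor painting",
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("zimage_omni.png")
```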
+ latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if image is not None and not isinstance(image, list): + image = [image] + num_condition_images = len(image) if image is not None else 0 + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2 + condition_images = [] + resized_images = [] + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + if height is not None and width is not None: + img = self.image_processor._resize_to_target_area(img, height * width) + else: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + if len(condition_images) > 0: + height = height or image_height + width = width or image_width + + else: + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + condition_latents = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents] + if self.do_classifier_free_guidance: + negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents] + + condition_siglip_embeds = self.prepare_siglip_embeds( + images=resized_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds] + if self.do_classifier_free_guidance: + negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds] + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds] + negative_condition_siglip_embeds = [ + None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds + ] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + condition_latents_model_input = condition_latents + negative_condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + negative_condition_siglip_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + condition_latents_model_input = condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + # Combine condition latents with target latent + current_batch_size = len(latent_model_input_list) + x_combined = [ + condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size) + ] + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size) + ] + + model_out_list = self.transformer( + x=x_combined, + t=timestep_model_input, + cap_feats=prompt_embeds_model_input, + siglip_feats=condition_siglip_embeds_model_input, + image_noise_mask=image_noise_mask, + return_dict=False, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = 
callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/quantizers/modelopt/modelopt_quantizer.py b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py index 534f752321b3..7312036f52d0 100644 --- a/src/diffusers/quantizers/modelopt/modelopt_quantizer.py +++ b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py @@ -27,7 +27,7 @@ class NVIDIAModelOptQuantizer(DiffusersQuantizer): r""" - Diffusers Quantizer for TensorRT Model Optimizer + Diffusers Quantizer for Nvidia-Model Optimizer """ use_keep_in_fp32_modules = True diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 5dd8f56717df..c905a928c79d 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin): - Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row`, - - **Floating point X-bit quantization:** + - **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0) - Full function names: `fpx_weight_only` - Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must @@ -531,12 +531,18 @@ def post_init(self): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): - is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") - if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): + is_floatx_quant_type = self.quant_type.startswith("fp") + is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type + if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): raise ValueError( f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You " f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`." ) + elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"): + raise ValueError( + f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. " + f"Please downgrade to torchao <= 0.14.1 to use this quantization type." + ) raise ValueError( f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. 
If you think the " @@ -617,12 +623,11 @@ def _get_torchao_quant_type_to_method(cls): """ if is_torchao_available(): - # TODO(aryan): Support autoquant and sparsify + # TODO(aryan): Support sparsify from torchao.quantization import ( float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, - fpx_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -630,6 +635,8 @@ def _get_torchao_quant_type_to_method(cls): uintx_weight_only, ) + if is_torchao_version("<=", "0.14.1"): + from torchao.quantization import fpx_weight_only # TODO(aryan): Add a note on how to use PerAxis and PerGroup observers from torchao.quantization.observer import PerRow, PerTensor @@ -650,18 +657,21 @@ def generate_float8dq_types(dtype: torch.dtype): return types def generate_fpx_quantization_types(bits: int): - types = {} + if is_torchao_version("<=", "0.14.1"): + types = {} - for ebits in range(1, bits): - mbits = bits - ebits - 1 - types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) + for ebits in range(1, bits): + mbits = bits - ebits - 1 + types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits) - non_sign_bits = bits - 1 - default_ebits = (non_sign_bits + 1) // 2 - default_mbits = non_sign_bits - default_ebits - types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) + non_sign_bits = bits - 1 + default_ebits = (non_sign_bits + 1) // 2 + default_mbits = non_sign_bits - default_ebits + types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits) - return types + return types + else: + raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0") INT4_QUANTIZATION_TYPES = { # int4 weight + bfloat16/float16 activation @@ -710,15 +720,15 @@ def generate_fpx_quantization_types(bits: int): **generate_float8dq_types(torch.float8_e4m3fn), # float8 weight + float8 activation (static) "float8_static_activation_float8_weight": float8_static_activation_float8_weight, - # For fpx, only x <= 8 is supported by default. 
Other dtypes can be explored by users directly - # fpx weight + bfloat16/float16 activation - **generate_fpx_quantization_types(3), - **generate_fpx_quantization_types(4), - **generate_fpx_quantization_types(5), - **generate_fpx_quantization_types(6), - **generate_fpx_quantization_types(7), } + if is_torchao_version("<=", "0.14.1"): + FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3)) + FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4)) + FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5)) + FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6)) + FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7)) + UINTX_QUANTIZATION_DTYPES = { "uintx_weight_only": uintx_weight_only, "uint1wo": partial(uintx_weight_only, dtype=torch.uint1), diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 2334c7af8630..11435b85eb4d 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -36,6 +36,9 @@ from ..base import DiffusersQuantizer +logger = logging.get_logger(__name__) + + if TYPE_CHECKING: from ...models.modeling_utils import ModelMixin @@ -83,11 +86,19 @@ def _update_torch_safe_globals(): ] try: from torchao.dtypes import NF4Tensor - from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl - from torchao.dtypes.uintx.uint4_layout import UInt4Tensor from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor - safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor]) + safe_globals.extend([UintxTensor, UintxAQTTensorImpl, NF4Tensor]) + + # note: is_torchao_version(">=", "0.16.0") does not work correctly + # with torchao nightly, so using a ">" check which does work correctly + if is_torchao_version(">", "0.15.0"): + pass + else: + from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl + from torchao.dtypes.uintx.uint4_layout import UInt4Tensor + + safe_globals.extend([UInt4Tensor, Float8AQTTensorImpl]) except (ImportError, ModuleNotFoundError) as e: logger.warning( @@ -123,9 +134,6 @@ def fuzzy_match_size(config_name: str) -> Optional[str]: return None -logger = logging.get_logger(__name__) - - def _quantization_type(weight): from torchao.dtypes import AffineQuantizedTensor from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor @@ -336,7 +344,6 @@ def get_cuda_warm_up_factor(self): from torchao.core.config import AOBaseConfig quant_type = self.quantization_config.quant_type - # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype if isinstance(quant_type, AOBaseConfig): # Extract size digit using fuzzy match on the class name config_name = quant_type.__class__.__name__ diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 29052c1ba0cb..4199e75bf331 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -66,6 +66,7 @@ _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"] _import_structure["scheduling_lcm"] = ["LCMScheduler"] + _import_structure["scheduling_ltx_euler_ancestral_rf"] = ["LTXEulerAncestralRFScheduler"] _import_structure["scheduling_pndm"] = ["PNDMScheduler"] _import_structure["scheduling_repaint"] 
= ["RePaintScheduler"] _import_structure["scheduling_sasolver"] = ["SASolverScheduler"] @@ -168,6 +169,7 @@ from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler from .scheduling_lcm import LCMScheduler + from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sasolver import SASolverScheduler diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py index 767fa9157f59..23c0e138c4ce 100644 --- a/src/diffusers/schedulers/scheduling_consistency_decoder.py +++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py @@ -14,7 +14,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -28,8 +28,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -40,6 +40,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -71,6 +78,22 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput): class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin): + """ + A scheduler for the consistency decoder used in Stable Diffusion pipelines. + + This scheduler implements a two-step denoising process using consistency models for decoding latent representations + into images. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, *optional*, defaults to `1024`): + The number of diffusion steps to train the model. + sigma_data (`float`, *optional*, defaults to `0.5`): + The standard deviation of the data distribution. Used for computing the skip and output scaling factors. + """ + order = 1 @register_to_config @@ -78,7 +101,7 @@ def __init__( self, num_train_timesteps: int = 1024, sigma_data: float = 0.5, - ): + ) -> None: betas = betas_for_alpha_bar(num_train_timesteps) alphas = 1.0 - betas @@ -98,8 +121,18 @@ def __init__( def set_timesteps( self, num_inference_steps: Optional[int] = None, - device: Union[str, torch.device] = None, - ): + device: Optional[Union[str, torch.device]] = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). 
+ + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. Currently, only + `2` inference steps are supported. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ if num_inference_steps != 2: raise ValueError("Currently more than 2 inference steps are not supported.") @@ -111,7 +144,15 @@ def set_timesteps( self.c_in = self.c_in.to(device) @property - def init_noise_sigma(self): + def init_noise_sigma(self) -> torch.Tensor: + """ + Return the standard deviation of the initial noise distribution. + + Returns: + `torch.Tensor`: + The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first + timestep. + """ return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]] def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: @@ -146,20 +187,20 @@ def step( Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. - timestep (`float`): + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): - A random number generator. + A random number generator for reproducibility. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a - [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`. + [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`. Returns: - [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`: - If return_dict is `True`, - [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise + [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`: + If `return_dict` is `True`, + [`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 386a43db0f9c..195ff81b4c91 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -83,7 +83,7 @@ def __init__( s_noise: float = 1.0, rho: float = 7.0, clip_denoised: bool = True, - ): + ) -> None: # standard deviation of the initial noise distribution self.init_noise_sigma = sigma_max @@ -102,21 +102,29 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index, or `None` if not yet initialized. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index, or `None` if not yet set. 
""" return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -151,7 +159,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T self.is_scale_input_called = True return sample - def sigma_to_t(self, sigmas: Union[float, np.ndarray]): + def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray: """ Gets scaled timesteps from the Karras sigmas for input to the consistency model. @@ -160,8 +168,8 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]): A single Karras sigma or an array of Karras sigmas. Returns: - `float` or `np.ndarray`: - A scaled input timestep or scaled input timestep array. + `np.ndarray`: + A scaled input timestep array. """ if not isinstance(sigmas, np.ndarray): sigmas = np.array(sigmas, dtype=np.float64) @@ -173,14 +181,14 @@ def sigma_to_t(self, sigmas: Union[float, np.ndarray]): def set_timesteps( self, num_inference_steps: Optional[int] = None, - device: Union[str, torch.device] = None, + device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, - ): + ) -> None: """ Sets the timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. @@ -244,9 +252,19 @@ def set_timesteps( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Modified _convert_to_karras implementation that takes in ramp as argument - def _convert_to_karras(self, ramp): - """Constructs the noise schedule of Karras et al. (2022).""" + def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray: + """ + Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative + Models](https://huggingface.co/papers/2206.00364). + Args: + ramp (`np.ndarray`): + A ramp array of values between 0 and 1 used to interpolate between sigma_min and sigma_max. + + Returns: + `np.ndarray`: + The Karras sigma schedule array. + """ sigma_min: float = self.config.sigma_min sigma_max: float = self.config.sigma_max @@ -256,14 +274,25 @@ def _convert_to_karras(self, ramp): sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def get_scalings(self, sigma): + def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes the scaling factors for the consistency model output. + + Args: + sigma (`torch.Tensor`): + The current sigma value in the noise schedule. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output). 
+ """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out - def get_scalings_for_boundary_condition(self, sigma): + def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Gets the scalings used in the consistency model parameterization (from Appendix C of the [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition. @@ -275,7 +304,7 @@ def get_scalings_for_boundary_condition(self, sigma): The current sigma in the Karras sigma schedule. Returns: - `tuple`: + `Tuple[torch.Tensor, torch.Tensor]`: A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out` (which weights the consistency model output) is the second element. """ @@ -348,13 +377,13 @@ def step( Args: model_output (`torch.Tensor`): The direct output from the learned diffusion model. - timestep (`float`): + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): A random number generator. - return_dict (`bool`, *optional*, defaults to `True`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`. @@ -406,7 +435,10 @@ def step( # Noise is not used for onestep sampling. if len(self.timesteps) > 1: noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + model_output.shape, + dtype=model_output.dtype, + device=model_output.device, + generator=generator, ) else: noise = torch.zeros_like(model_output) @@ -475,5 +507,12 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples - def __len__(self): + def __len__(self) -> int: + """ + Returns the number of training timesteps. + + Returns: + `int`: + The number of training timesteps configured for the scheduler. + """ return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py index 103cca81c6a5..9c6b0fcf69b6 100644 --- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -15,7 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm import math -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import numpy as np import torch @@ -36,27 +36,30 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - sigma_min (`float`, *optional*, defaults to 0.3): + sigma_min (`float`, defaults to `0.3`): Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. - sigma_max (`float`, *optional*, defaults to 500): + sigma_max (`float`, defaults to `500`): Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. - sigma_data (`float`, *optional*, defaults to 1.0): + sigma_data (`float`, defaults to `1.0`): The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. 
- sigma_schedule (`str`, *optional*, defaults to `exponential`): - Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper - (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential - schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl. - num_train_timesteps (`int`, defaults to 1000): + sigma_schedule (`str`, defaults to `"exponential"`): + Sigma schedule to compute the `sigmas`. Must be one of `"exponential"` or `"karras"`. The exponential + schedule was incorporated in [stabilityai/cosxl](https://huggingface.co/stabilityai/cosxl). The Karras + schedule is introduced in the [EDM](https://huggingface.co/papers/2206.00364) paper. + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - solver_order (`int`, defaults to 2): + solver_order (`int`, defaults to `2`): The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. - prediction_type (`str`, defaults to `v_prediction`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + prediction_type (`str`, defaults to `"v_prediction"`): + Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion + process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper). - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + rho (`float`, defaults to `7.0`): + The parameter for calculating the Karras sigma schedule from the EDM + [paper](https://huggingface.co/papers/2206.00364). + solver_type (`str`, defaults to `"midpoint"`): + Solver type for the second-order solver. Must be one of `"midpoint"` or `"heun"`. The solver type slightly + affects the sample quality, especially for a small number of steps. It is recommended to use `"midpoint"`. lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. @@ -65,8 +68,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference steps, but sometimes may result in blurring. final_sigmas_type (`str`, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + The final `sigma` value for the noise schedule during the sampling process. Must be one of `"zero"` or + `"sigma_min"`. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If + `"zero"`, the final sigma is set to 0. 
""" _compatibles = [] @@ -78,16 +82,16 @@ def __init__( sigma_min: float = 0.3, sigma_max: float = 500, sigma_data: float = 1.0, - sigma_schedule: str = "exponential", + sigma_schedule: Literal["exponential", "karras"] = "exponential", num_train_timesteps: int = 1000, solver_order: int = 2, - prediction_type: str = "v_prediction", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "v_prediction", rho: float = 7.0, - solver_type: str = "midpoint", + solver_type: Literal["midpoint", "heun"] = "midpoint", lower_order_final: bool = True, euler_at_final: bool = False, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - ): + final_sigmas_type: Literal["zero", "sigma_min"] = "zero", + ) -> None: if solver_type not in ["midpoint", "heun"]: if solver_type in ["logrho", "bh1", "bh2"]: self.register_to_config(solver_type="midpoint") @@ -113,26 +117,40 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def init_noise_sigma(self): - # standard deviation of the initial noise distribution + def init_noise_sigma(self) -> float: + """ + The standard deviation of the initial noise distribution. + + Returns: + `float`: + The initial noise sigma value computed as `sqrt(sigma_max^2 + 1)`. + """ return (self.config.sigma_max**2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index, or `None` if not yet initialized. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index, or `None` if not yet set. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -143,19 +161,63 @@ def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs - def precondition_inputs(self, sample, sigma): + def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the input sample by scaling it according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor to precondition. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The scaled input sample. + """ c_in = self._get_conditioning_c_in(sigma) scaled_sample = sample * c_in return scaled_sample - def precondition_noise(self, sigma): + def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the noise level by computing a normalized timestep representation. + + Args: + sigma (`float` or `torch.Tensor`): + The sigma (noise level) value to precondition. + + Returns: + `torch.Tensor`: + The preconditioned noise value computed as `atan(sigma) / pi * 2`. 
+ """ if not isinstance(sigma, torch.Tensor): sigma = torch.tensor([sigma]) return sigma.atan() / math.pi * 2 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs - def precondition_outputs(self, sample, model_output, sigma): + def precondition_outputs( + self, + sample: torch.Tensor, + model_output: torch.Tensor, + sigma: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Precondition the model outputs according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor. + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The denoised sample computed by combining the skip connection and output scaling. + """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -173,13 +235,13 @@ def precondition_outputs(self, sample, model_output, sigma): # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that + need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample tensor. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: @@ -195,12 +257,14 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T self.is_scale_input_called = True return sample - def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + def set_timesteps( + self, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. @@ -242,8 +306,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc self.noise_sampler = None # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas - def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _compute_karras_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364). + + Args: + ramp (`torch.Tensor`): + A tensor of values in [0, 1] representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. 
+ + Returns: + `torch.Tensor`: + The computed Karras sigma schedule. + """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -254,10 +337,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch. return sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas - def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Implementation closely follows k-diffusion. - + def _compute_exponential_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the exponential sigma schedule. Implementation closely follows k-diffusion: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + + Args: + ramp (`torch.Tensor`): + A tensor of values representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed exponential sigma schedule. """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -265,7 +365,7 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t return sigmas # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -301,7 +401,19 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert sigma to alpha and sigma_t values for the diffusion process. + + Args: + sigma (`torch.Tensor`): + The sigma (noise level) value. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing `alpha_t` (always 1 since inputs are pre-scaled) and `sigma_t` (same as input + sigma). + """ alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 sigma_t = sigma @@ -354,7 +466,10 @@ def dpm_solver_first_order_update( `torch.Tensor`: The sample tensor at the previous timestep. """ - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -464,7 +579,7 @@ def index_for_timestep( return step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None: """ Initialize the step_index counter for the scheduler. 
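As a quick reference for the docstrings expanded in the hunks above, the following standalone sketch (illustrative only, not part of the patch) reproduces the Karras and exponential sigma schedules and the EDM preconditioning factors using the same formulas; the helper names are ad hoc and do not exist in diffusers.

import math
import numpy as np

def karras_sigmas(ramp, sigma_min=0.3, sigma_max=500.0, rho=7.0):
    # Interpolate between sigma_max and sigma_min in rho-space (Karras et al., 2022),
    # matching the ramp-based formula documented for _compute_karras_sigmas above.
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

def exponential_sigmas(num_steps, sigma_min=0.3, sigma_max=500.0):
    # Log-linearly spaced sigmas from sigma_max down to sigma_min (k-diffusion style),
    # as documented for _compute_exponential_sigmas above.
    return np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_steps))

def edm_preconditioning(sigma, sigma_data=1.0):
    # Scaling factors described by precondition_inputs / precondition_outputs above.
    c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    # Normalized timestep representation used by precondition_noise above.
    c_noise = math.atan(sigma) / math.pi * 2
    return c_in, c_skip, c_out, c_noise

ramp = np.linspace(0, 1, 8)
print(karras_sigmas(ramp))
print(exponential_sigmas(8))
print(edm_preconditioning(sigma=2.0))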
@@ -485,7 +600,7 @@ def step( model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -495,20 +610,19 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`int`): + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. generator (`torch.Generator`, *optional*): A random number generator. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -540,7 +654,10 @@ def step( [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() ) self.noise_sampler = BrownianTreeNoiseSampler( - model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + model_output, + sigma_min=self.config.sigma_min, + sigma_max=self.config.sigma_max, + seed=seed, ) noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( model_output.device @@ -612,9 +729,27 @@ def add_noise( return noisy_samples # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in - def _get_conditioning_c_in(self, sigma): + def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Compute the input conditioning factor for the EDM formulation. + + Args: + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `float` or `torch.Tensor`: + The input conditioning factor `c_in`. + """ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) return c_in - def __len__(self): + def __len__(self) -> int: + """ + Returns the number of training timesteps. + + Returns: + `int`: + The number of training timesteps configured for the scheduler. + """ return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d7fe29a72ac9..92c3e20013dd 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. 
Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -77,6 +77,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index f2683d1304ec..1a77a652786d 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -77,6 +77,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -93,14 +100,13 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) -def rescale_zero_terminal_snr(alphas_cumprod): +def rescale_zero_terminal_snr(alphas_cumprod: torch.Tensor) -> torch.Tensor: """ - Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - + Rescales betas to have zero terminal SNR Based on (Algorithm 1)[https://huggingface.co/papers/2305.08891] Args: - betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. + alphas_cumprod (`torch.Tensor`): + The alphas cumulative products that the scheduler is being initialized with. Returns: `torch.Tensor`: rescaled betas with zero terminal SNR @@ -135,11 +141,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to 0.00085): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to 0.0120): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`str`, defaults to `"scaled_linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
trained_betas (`np.ndarray`, *optional*): @@ -172,6 +178,8 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + snr_shift_scale (`float`, defaults to 3.0): + Shift scale for SNR. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -183,15 +191,15 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120, - beta_schedule: str = "scaled_linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, snr_shift_scale: float = 3.0, ): @@ -201,7 +209,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float64, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -231,7 +247,7 @@ def __init__( self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -258,7 +274,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -310,7 +330,7 @@ def step( sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -321,7 +341,7 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. 
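The `laplace` option added to `betas_for_alpha_bar` in the files above can be sanity-checked in isolation. The snippet below is an illustrative sketch, not library code: `laplace_alpha_bar` copies the branch added in this diff, while `betas_from_alpha_bar` is an assumed, ad hoc discretization loop that turns an `alpha_bar` function into betas.

import math
import torch

def laplace_alpha_bar(t):
    # Copied from the `laplace` branch added to betas_for_alpha_bar in this diff.
    lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
    snr = math.exp(lmb)
    return math.sqrt(snr / (1 + snr))

def betas_from_alpha_bar(alpha_bar_fn, num_diffusion_timesteps=1000, max_beta=0.999):
    # Assumed discretization: beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

betas = betas_from_alpha_bar(laplace_alpha_bar)
print(betas[:3], betas[-3:])

Because the Laplace transform keeps `alpha_bar` close to 1 near t=0 and close to 0 near t=1, the resulting betas stay positive and grow toward the end of the schedule, like the other `alpha_bar`-derived schedules.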
@@ -480,5 +500,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 802d8f79779d..476f741bcdde 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -32,6 +33,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DDIMSchedulerState: common: CommonSchedulerState @@ -125,6 +129,10 @@ def __init__( prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: @@ -152,7 +160,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSch ) def scale_model_input( - self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DDIMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Args: @@ -190,7 +201,9 @@ def set_timesteps( def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): alpha_prod_t = state.common.alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + prev_timestep >= 0, + state.common.alphas_cumprod[prev_timestep], + state.final_alpha_cumprod, ) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 8ae13ad49d10..a3c9ed1f6258 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -49,7 +49,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -63,8 +63,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -75,6 +75,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -92,7 +99,7 @@ def alpha_bar_fn(t): # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -180,14 +187,14 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, **kwargs, ): @@ -203,7 +210,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -249,7 +264,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -301,20 +320,10 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + timestep (`int`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - eta (`float`): - The weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`, defaults to `False`): - If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary - because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no - clipping has happened, "corrected" `model_output` would coincide with the one provided as input and - `use_clipped_model_output` has no effect. - variance_noise (`torch.Tensor`): - Alternative to generating noise with `generator` by directly providing the noise for the variance - itself. Useful for methods such as [`CycleDiffusion`]. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or `tuple`. @@ -328,7 +337,8 @@ def step( # 1. get previous step value (=t+1) prev_timestep = timestep timestep = min( - timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1 + timestep - self.config.num_train_timesteps // self.num_inference_steps, + self.config.num_train_timesteps - 1, ) # 2. compute alphas, betas @@ -371,5 +381,5 @@ def step( return (prev_sample, pred_original_sample) return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 10873a082fee..76f0636fbf6c 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -51,7 +51,7 @@ class DDIMParallelSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -77,6 +77,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -94,7 +101,7 @@ def alpha_bar_fn(t): # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -259,7 +266,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def _get_variance(self, timestep, prev_timestep=None): + def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor: if prev_timestep is None: prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps @@ -272,7 +279,7 @@ def _get_variance(self, timestep, prev_timestep=None): return variance - def _batch_get_variance(self, t, prev_t): + def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor: alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)] alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0) @@ -328,7 +335,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -385,7 +392,7 @@ def step( sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[DDIMParallelSchedulerOutput, Tuple]: @@ -399,11 +406,13 @@ def step( sample (`torch.Tensor`): current instance of sample being created by diffusion process. eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped - predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when - `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would - coincide with the one provided as input and `use_clipped_model_output` will have not effect. - generator: random number generator. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This + correction is necessary because the predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches + the input and `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + Random number generator. variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we can directly provide the noise for the variance itself. 
This is useful for methods such as CycleDiffusion. (https://huggingface.co/papers/2210.05559) @@ -489,7 +498,10 @@ def step( if variance_noise is None: variance_noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, ) variance = std_dev_t * variance_noise @@ -506,7 +518,7 @@ def step( def batch_step_no_noise( self, model_output: torch.Tensor, - timesteps: List[int], + timesteps: torch.Tensor, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, @@ -521,7 +533,7 @@ def batch_step_no_noise( Args: model_output (`torch.Tensor`): direct output from learned diffusion model. - timesteps (`List[int]`): + timesteps (`torch.Tensor`): current discrete timesteps in the diffusion chain. This is now a list of integers. sample (`torch.Tensor`): current instance of sample being created by diffusion process. @@ -689,5 +701,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ded88b8e1e0a..2e2816bbf3ba 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -48,7 +48,7 @@ class DDPMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -62,8 +62,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -74,6 +74,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -185,7 +192,12 @@ def __init__( beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: Literal[ - "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", ] = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", @@ -203,10 +215,20 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "laplace": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace") elif beta_schedule == "sigmoid": # GeoDiff sigmoid schedule betas = torch.linspace(-6, 6, num_train_timesteps) @@ -259,7 +281,7 @@ def set_timesteps( Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. 
device (`str` or `torch.device`, *optional*): @@ -328,7 +350,14 @@ def _get_variance( t: int, predicted_variance: Optional[torch.Tensor] = None, variance_type: Optional[ - Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + Literal[ + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", + ] ] = None, ) -> torch.Tensor: """ @@ -463,7 +492,10 @@ def step( prev_t = self.previous_timestep(t) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None @@ -512,7 +544,10 @@ def step( if t > 0: device = model_output.device variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise @@ -611,7 +646,7 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor def __len__(self) -> int: return self.config.num_train_timesteps - def previous_timestep(self, timestep: int) -> int: + def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: """ Compute the previous timestep in the diffusion chain. @@ -620,7 +655,7 @@ def previous_timestep(self, timestep: int) -> int: The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index a3264f54f572..e02b7ea0c0f3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -32,6 +33,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DDPMSchedulerState: common: CommonSchedulerState @@ -42,7 +46,12 @@ class DDPMSchedulerState: num_inference_steps: Optional[int] = None @classmethod - def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray): + def create( + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + ): return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps) @@ -105,6 +114,10 @@ def __init__( prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState: @@ -123,7 +136,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSch ) def scale_model_input( - self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DDPMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Args: diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index 941fc16be080..b02c5376f2c6 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -50,7 +50,7 @@ class DDPMParallelSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -64,8 +64,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -76,6 +76,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -142,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): For more details, see the original paper: https://huggingface.co/papers/2006.11239 Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - variance_type (`str`): - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + trained_betas (`np.ndarray`, *optional*): + Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
+ variance_type (`str`, defaults to `"fixed_small"`): + Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - clip_sample (`bool`, default `True`): - option to clip predicted sample for numerical stability. - clip_sample_range (`float`, default `1.0`): - the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. - prediction_type (`str`, default `epsilon`, optional): - prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + clip_sample (`bool`, defaults to `True`): + Option to clip predicted sample for numerical stability. + prediction_type (`str`, defaults to `"epsilon"`): + Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://huggingface.co/papers/2210.02303) - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method (introduced by Imagen, https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`. - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid only when `thresholding=True`. - timestep_spacing (`str`, default `"leading"`): + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, default `0`): + steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and @@ -195,7 +205,12 @@ def __init__( beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, variance_type: Literal[ - "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range" + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", ] = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", @@ -213,10 +228,20 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) + elif beta_schedule == "laplace": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace") elif beta_schedule == "sigmoid": # GeoDiff sigmoid schedule betas = torch.linspace(-6, 6, num_train_timesteps) @@ -271,7 +296,7 @@ def set_timesteps( Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): @@ -341,7 +366,14 @@ def _get_variance( t: int, predicted_variance: Optional[torch.Tensor] = None, variance_type: Optional[ - Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"] + Literal[ + "fixed_small", + "fixed_small_log", + "fixed_large", + "fixed_large_log", + "learned", + "learned_range", + ] ] = None, ) -> torch.Tensor: """ @@ -449,7 +481,7 @@ def step( model_output: torch.Tensor, timestep: int, sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[DDPMParallelSchedulerOutput, Tuple]: """ @@ -461,7 +493,8 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.Tensor`): current instance of sample being created by diffusion process. - generator: random number generator. + generator (`torch.Generator`, *optional*): + Random number generator. return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class Returns: @@ -474,7 +507,10 @@ def step( prev_t = self.previous_timestep(t) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None @@ -523,7 +559,10 @@ def step( if t > 0: device = model_output.device variance_noise = randn_tensor( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=device, + dtype=model_output.dtype, ) if self.variance_type == "fixed_small_log": variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise @@ -546,7 +585,7 @@ def step( def batch_step_no_noise( self, model_output: torch.Tensor, - timesteps: List[int], + timesteps: torch.Tensor, sample: torch.Tensor, ) -> torch.Tensor: """ @@ -559,8 +598,8 @@ def batch_step_no_noise( Args: model_output (`torch.Tensor`): direct output from learned diffusion model. - timesteps (`List[int]`): - current discrete timesteps in the diffusion chain. This is now a list of integers. + timesteps (`torch.Tensor`): + Current discrete timesteps in the diffusion chain. This is a tensor of integers. sample (`torch.Tensor`): current instance of sample being created by diffusion process. 
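The `scaled_linear` branch reformatted above (here and in the other schedulers touched by this diff) is easy to misread, so as an illustrative aside (not part of the patch): it spaces the square roots of beta linearly and then squares them, which yields smaller early betas than the plain `linear` schedule. A minimal comparison, assuming the standard `alphas_cumprod = cumprod(1 - betas)` relation these schedulers rely on:

import torch

num_train_timesteps = 1000
beta_start, beta_end = 0.0001, 0.02

# Plain linear spacing of beta values.
linear_betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)

# "scaled_linear": linear spacing of sqrt(beta), then squared, as in the reformatted __init__ above.
scaled_linear_betas = (
    torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)

# Cumulative products of alpha_t = 1 - beta_t determine the noise level per timestep.
linear_alphas_cumprod = torch.cumprod(1.0 - linear_betas, dim=0)
scaled_alphas_cumprod = torch.cumprod(1.0 - scaled_linear_betas, dim=0)

print(linear_alphas_cumprod[-1].item(), scaled_alphas_cumprod[-1].item())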
@@ -574,7 +613,10 @@ def batch_step_no_noise( t = t.view(-1, *([1] * (model_output.ndim - 1))) prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1))) - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: pass @@ -705,7 +747,7 @@ def __len__(self): return self.config.num_train_timesteps # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep - def previous_timestep(self, timestep): + def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]: """ Compute the previous timestep in the diffusion chain. @@ -714,7 +756,7 @@ def previous_timestep(self, timestep): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 09ce338a9222..7c2dfd8e503f 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -34,7 +34,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -60,6 +60,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -84,33 +91,35 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
- trained_betas (`np.ndarray`, *optional*): + trained_betas (`np.ndarray` or `List[float]`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, defaults to 2): + solver_order (`int`, defaults to `2`): The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, defaults to `epsilon`): + prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://huggingface.co/papers/2210.02303) paper). + `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`. thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): + dynamic_thresholding_ratio (`float`, defaults to `0.995`): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): + sample_max_value (`float`, defaults to `1.0`): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. - algorithm_type (`str`, defaults to `deis`): + algorithm_type (`"deis"`, defaults to `"deis"`): The algorithm type for the solver. + solver_type (`"logrho"`, defaults to `"logrho"`): + Solver type for DEIS. lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. use_karras_sigmas (`bool`, *optional*, defaults to `False`): @@ -121,11 +130,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): use_beta_sigmas (`bool`, *optional*, defaults to `False`): Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. - timestep_spacing (`str`, defaults to `"linspace"`): + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to `1.0`): + The flow shift parameter for flow-based models. + timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): + steps_offset (`int`, defaults to `0`): An offset added to the inference steps, as required by some model families. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to use dynamic shifting for the noise schedule. + time_shift_type (`"exponential"`, defaults to `"exponential"`): + The type of time shifting to apply. 
""" _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -137,29 +154,38 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - algorithm_type: str = "deis", - solver_type: str = "logrho", + algorithm_type: Literal["deis"] = "deis", + solver_type: Literal["logrho"] = "logrho", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, - timestep_spacing: str = "linspace", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, use_dynamic_shifting: bool = False, - time_shift_type: str = "exponential", - ): + time_shift_type: Literal["exponential"] = "exponential", + ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): raise ValueError( "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." ) @@ -169,7 +195,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -211,21 +245,21 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 
@@ -236,8 +270,11 @@ def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index def set_timesteps( - self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None - ): + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + mu: Optional[float] = None, + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -246,6 +283,9 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + mu (`float`, *optional*): + The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and + `time_shift_type="exponential"`. """ if mu is not None: assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" @@ -363,7 +403,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -400,7 +440,7 @@ def _sigma_to_t(self, sigma, log_sigmas): return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert sigma values to alpha_t and sigma_t values. @@ -422,7 +462,7 @@ def _sigma_to_alpha_sigma_t(self, sigma): return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: """ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). @@ -648,7 +688,10 @@ def deis_first_order_update( "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", ) - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -714,7 +757,11 @@ def multistep_deis_second_order_update( m0, m1 = model_output_list[-1], model_output_list[-2] - rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 + rho_t, rho_s0, rho_s1 = ( + sigma_t / alpha_t, + sigma_s0 / alpha_s0, + sigma_s1 / alpha_s1, + ) if self.config.algorithm_type == "deis": @@ -854,7 +901,7 @@ def index_for_timestep( return step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None: """ Initialize the step_index counter for the scheduler. 
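Editor's note: the _sigma_to_alpha_sigma_t helper annotated above uses the standard variance-preserving reparameterization; a minimal standalone sketch (not part of the patch) for reference:

import torch

def sigma_to_alpha_sigma_t(sigma: torch.Tensor):
    # alpha_t = 1 / sqrt(sigma**2 + 1) and sigma_t = sigma * alpha_t,
    # so that alpha_t**2 + sigma_t**2 == 1 and sigma == sigma_t / alpha_t.
    alpha_t = 1 / (sigma**2 + 1) ** 0.5
    sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

sigma = torch.tensor([0.1, 1.0, 10.0])
alpha_t, sigma_t = sigma_to_alpha_sigma_t(sigma)
print(torch.allclose(alpha_t**2 + sigma_t**2, torch.ones_like(sigma)))  # True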
@@ -884,18 +931,17 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`int`): + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -1000,5 +1046,5 @@ def add_noise( noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index 0a9082208cf4..0c576467a19a 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -52,7 +52,7 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -78,6 +78,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -98,7 +105,6 @@ def rescale_zero_terminal_snr(alphas_cumprod): """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) - Args: betas (`torch.Tensor`): the betas that the scheduler is being initialized with. @@ -168,11 +174,14 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True`. timestep_spacing (`str`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. Choose from + `leading`, `linspace` or `trailing`. rescale_betas_zero_snr (`bool`, defaults to `False`): Whether to rescale the betas to have zero terminal SNR. 
This enables the model to generate very bright and dark samples instead of limiting it to samples with medium brightness. Loosely related to [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + snr_shift_scale (`float`, defaults to 3.0): + Shift scale for SNR. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -184,15 +193,15 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120, - beta_schedule: str = "scaled_linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon", clip_sample_range: float = 1.0, sample_max_value: float = 1.0, - timestep_spacing: str = "leading", + timestep_spacing: Literal["leading", "linspace", "trailing"] = "leading", rescale_betas_zero_snr: bool = False, snr_shift_scale: float = 3.0, ): @@ -202,7 +211,15 @@ def __init__( self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2 + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float64, + ) + ** 2 + ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule self.betas = betas_for_alpha_bar(num_train_timesteps) @@ -259,13 +276,20 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None """ return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Optional[Union[str, torch.device]] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None` (the default), the timesteps are not + moved. """ if num_inference_steps > self.config.num_train_timesteps: @@ -304,7 +328,27 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device) - def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): + def get_variables( + self, + alpha_prod_t: torch.Tensor, + alpha_prod_t_prev: torch.Tensor, + alpha_prod_t_back: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: + """ + Compute the variables used for DPM-Solver++ (2M) referencing the original implementation. + + Args: + alpha_prod_t (`torch.Tensor`): + The cumulative product of alphas at the current timestep. + alpha_prod_t_prev (`torch.Tensor`): + The cumulative product of alphas at the previous timestep. + alpha_prod_t_back (`torch.Tensor`, *optional*): + The cumulative product of alphas at the timestep before the previous timestep. + + Returns: + `tuple`: + A tuple containing the variables `h`, `r`, `lamb`, `lamb_next`. 
+ """ lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log() lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log() h = lamb_next - lamb @@ -317,7 +361,36 @@ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None) else: return h, None, lamb, lamb_next - def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): + def get_mult( + self, + h: torch.Tensor, + r: Optional[torch.Tensor], + alpha_prod_t: torch.Tensor, + alpha_prod_t_prev: torch.Tensor, + alpha_prod_t_back: Optional[torch.Tensor] = None, + ) -> Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ]: + """ + Compute the multipliers for the previous sample and the predicted original sample. + + Args: + h (`torch.Tensor`): + The log-SNR difference. + r (`torch.Tensor`): + The ratio of log-SNR differences. + alpha_prod_t (`torch.Tensor`): + The cumulative product of alphas at the current timestep. + alpha_prod_t_prev (`torch.Tensor`): + The cumulative product of alphas at the previous timestep. + alpha_prod_t_back (`torch.Tensor`, *optional*): + The cumulative product of alphas at the timestep before the previous timestep. + + Returns: + `tuple`: + A tuple containing the multipliers. + """ mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp() mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5 @@ -331,13 +404,13 @@ def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): def step( self, model_output: torch.Tensor, - old_pred_original_sample: torch.Tensor, + old_pred_original_sample: Optional[torch.Tensor], timestep: int, timestep_back: int, sample: torch.Tensor, eta: float = 0.0, use_clipped_model_output: bool = False, - generator=None, + generator: Optional[torch.Generator] = None, variance_noise: Optional[torch.Tensor] = None, return_dict: bool = False, ) -> Union[DDIMSchedulerOutput, Tuple]: @@ -348,8 +421,12 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`float`): + old_pred_original_sample (`torch.Tensor`): + The predicted original sample from the previous timestep. + timestep (`int`): The current discrete timestep in the diffusion chain. + timestep_back (`int`): + The timestep to look back to. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. 
eta (`float`): @@ -429,7 +506,12 @@ def step( return prev_sample, pred_original_sample else: denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample - noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype) + noise = randn_tensor( + sample.shape, + generator=generator, + device=sample.device, + dtype=sample.dtype, + ) x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise prev_sample = x_advanced @@ -517,5 +599,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index e7ba0ba1f30e..07cb64f32b58 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -34,7 +34,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -60,6 +60,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 71b9960bf2ff..66398073b29e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -31,6 +32,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class DPMSolverMultistepSchedulerState: common: CommonSchedulerState @@ -171,6 +175,10 @@ def __init__( timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolverMultistepSchedulerState: @@ -203,7 +211,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DPMSolv ) def set_timesteps( - self, state: DPMSolverMultistepSchedulerState, num_inference_steps: int, shape: Tuple + self, + state: DPMSolverMultistepSchedulerState, + num_inference_steps: int, + shape: Tuple, ) -> DPMSolverMultistepSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -301,10 +312,13 @@ def convert_model_output( if self.config.thresholding: # Dynamic thresholding in https://huggingface.co/papers/2205.11487 dynamic_max_val = jnp.percentile( - jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim)) + jnp.abs(x0_pred), + self.config.dynamic_thresholding_ratio, + axis=tuple(range(1, x0_pred.ndim)), ) dynamic_max_val = jnp.maximum( - dynamic_max_val, self.config.sample_max_value * jnp.ones_like(dynamic_max_val) + dynamic_max_val, + self.config.sample_max_value * jnp.ones_like(dynamic_max_val), ) x0_pred = jnp.clip(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val return x0_pred @@ -385,7 +399,11 @@ def multistep_dpm_solver_second_order_update( """ t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] m0, m1 = model_output_list[-1], model_output_list[-2] - lambda_t, lambda_s0, lambda_s1 = state.lambda_t[t], state.lambda_t[s0], state.lambda_t[s1] + lambda_t, lambda_s0, lambda_s1 = ( + state.lambda_t[t], + state.lambda_t[s0], + state.lambda_t[s1], + ) alpha_t, alpha_s0 = state.alpha_t[t], state.alpha_t[s0] sigma_t, sigma_s0 = state.sigma_t[t], state.sigma_t[s0] h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 @@ -443,7 +461,12 @@ def multistep_dpm_solver_third_order_update( Returns: `jnp.ndarray`: the sample tensor at the previous timestep. 
""" - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + t, s0, s1, s2 = ( + prev_timestep, + timestep_list[-1], + timestep_list[-2], + timestep_list[-3], + ) m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( state.lambda_t[t], @@ -615,7 +638,10 @@ def step_3(state: DPMSolverMultistepSchedulerState) -> jnp.ndarray: return FlaxDPMSolverMultistepSchedulerOutput(prev_sample=prev_sample, state=state) def scale_model_input( - self, state: DPMSolverMultistepSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: DPMSolverMultistepSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 6696b0375f9f..2da90d287cf8 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -34,7 +34,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -60,6 +60,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 81c9e4134f57..6f905a623d70 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -117,7 +117,7 @@ def __call__(self, sigma, sigma_next): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -131,8 +131,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. 
+ alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -143,6 +143,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 55c9fb6e7384..e9bf815aba86 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -36,7 +36,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -50,8 +50,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -62,6 +62,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -86,42 +93,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, *optional*): + trained_betas (`np.ndarray` or `List[float]`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, defaults to 2): + solver_order (`int`, defaults to `2`): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. 
- prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://huggingface.co/papers/2210.02303) paper). + `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`. thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): + dynamic_thresholding_ratio (`float`, defaults to `0.995`): The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): + sample_max_value (`float`, defaults to `1.0`): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver` + algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. - solver_type (`str`, defaults to `midpoint`): + solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`): Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): + lower_order_final (`bool`, defaults to `False`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. use_karras_sigmas (`bool`, *optional*, defaults to `False`): @@ -132,15 +139,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): use_beta_sigmas (`bool`, *optional*, defaults to `False`): Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. - final_sigmas_type (`str`, *optional*, defaults to `"zero"`): + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to `1.0`): + The flow shift parameter for flow-based models. + final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0. lambda_min_clipped (`float`, defaults to `-inf`): Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. + variance_type (`"learned"` or `"learned_range"`, *optional*): + Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's + output contains the predicted Gaussian variance. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to use dynamic shifting for the noise schedule. + time_shift_type (`"exponential"`, defaults to `"exponential"`): + The type of time shifting to apply. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -152,27 +167,27 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", + algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++", + solver_type: Literal["midpoint", "heun"] = "midpoint", lower_order_final: bool = False, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, + variance_type: Optional[Literal["learned", "learned_range"]] = None, use_dynamic_shifting: bool = False, - time_shift_type: str = "exponential", - ): + time_shift_type: Literal["exponential"] = "exponential", + ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: @@ -242,6 +257,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: Args: num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. + + Returns: + `List[int]`: + The list of solver orders for each timestep. """ steps = num_inference_steps order = self.config.solver_order @@ -276,21 +295,29 @@ def get_order_list(self, num_inference_steps: int) -> List[int]: return orders @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index. 
""" return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + + Returns: + `int` or `None`: + The begin index. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -302,19 +329,21 @@ def set_begin_index(self, begin_index: int = 0): def set_timesteps( self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None, timesteps: Optional[List[int]] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + mu (`float`, *optional*): + The mu parameter for dynamic shifting. timesteps (`List[int]`, *optional*): Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is @@ -453,7 +482,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -490,7 +519,7 @@ def _sigma_to_t(self, sigma, log_sigmas): return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert sigma values to alpha_t and sigma_t values. @@ -512,7 +541,7 @@ def _sigma_to_alpha_sigma_t(self, sigma): return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: """ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). 
@@ -637,7 +666,7 @@ def convert_model_output( self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: """ @@ -733,7 +762,7 @@ def dpm_solver_first_order_update( self, model_output: torch.Tensor, *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -797,7 +826,7 @@ def singlestep_dpm_solver_second_order_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -908,7 +937,7 @@ def singlestep_dpm_solver_third_order_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, + sample: Optional[torch.Tensor] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -1030,8 +1059,8 @@ def singlestep_dpm_solver_update( self, model_output_list: List[torch.Tensor], *args, - sample: torch.Tensor = None, - order: int = None, + sample: Optional[torch.Tensor] = None, + order: Optional[int] = None, noise: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: @@ -1125,7 +1154,7 @@ def index_for_timestep( return step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None: """ Initialize the step_index counter for the scheduler. @@ -1146,7 +1175,7 @@ def step( model_output: torch.Tensor, timestep: Union[int, torch.Tensor], sample: torch.Tensor, - generator=None, + generator: Optional[torch.Generator] = None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -1156,11 +1185,13 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`int`): + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + generator (`torch.Generator`, *optional*): + A random number generator for stochastic sampling. + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: @@ -1277,5 +1308,5 @@ def add_noise( noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index d4e8ca5e8b18..a573f032cad8 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -175,13 +175,37 @@ def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs - def precondition_inputs(self, sample, sigma): + def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the input sample by scaling it according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor to precondition. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. 
+ + Returns: + `torch.Tensor`: + The scaled input sample. + """ c_in = self._get_conditioning_c_in(sigma) scaled_sample = sample * c_in return scaled_sample # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise - def precondition_noise(self, sigma): + def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the noise level by applying a logarithmic transformation. + + Args: + sigma (`float` or `torch.Tensor`): + The sigma (noise level) value to precondition. + + Returns: + `torch.Tensor`: + The preconditioned noise value computed as `0.25 * log(sigma)`. + """ if not isinstance(sigma, torch.Tensor): sigma = torch.tensor([sigma]) @@ -190,7 +214,27 @@ def precondition_noise(self, sigma): return c_noise # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs - def precondition_outputs(self, sample, model_output, sigma): + def precondition_outputs( + self, + sample: torch.Tensor, + model_output: torch.Tensor, + sigma: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Precondition the model outputs according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor. + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The denoised sample computed by combining the skip connection and output scaling. + """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -208,13 +252,13 @@ def precondition_outputs(self, sample, model_output, sigma): # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that + need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample tensor. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: @@ -274,8 +318,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas - def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _compute_karras_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364). + + Args: + ramp (`torch.Tensor`): + A tensor of values in [0, 1] representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. 
+ + Returns: + `torch.Tensor`: + The computed Karras sigma schedule. + """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -286,10 +349,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch. return sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas - def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Implementation closely follows k-diffusion. - + def _compute_exponential_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the exponential sigma schedule. Implementation closely follows k-diffusion: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + + Args: + ramp (`torch.Tensor`): + A tensor of values representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed exponential sigma schedule. """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -433,7 +513,10 @@ def dpm_solver_first_order_update( `torch.Tensor`: The sample tensor at the previous timestep. """ - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -684,7 +767,10 @@ def step( if self.config.algorithm_type == "sde-dpmsolver++": noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, ) else: noise = None @@ -757,7 +843,18 @@ def add_noise( return noisy_samples # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in - def _get_conditioning_c_in(self, sigma): + def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Compute the input conditioning factor for the EDM formulation. + + Args: + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `float` or `torch.Tensor`: + The input conditioning factor `c_in`. + """ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) return c_in diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index 2ed05d396514..604d8b3ea6fa 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import torch @@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - sigma_min (`float`, *optional*, defaults to 0.002): + sigma_min (`float`, *optional*, defaults to `0.002`): Minimum noise magnitude in the sigma schedule. 
This was set to 0.002 in the EDM paper [1]; a reasonable range is [0, 10]. - sigma_max (`float`, *optional*, defaults to 80.0): + sigma_max (`float`, *optional*, defaults to `80.0`): Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable range is [0.2, 80.0]. - sigma_data (`float`, *optional*, defaults to 0.5): + sigma_data (`float`, *optional*, defaults to `0.5`): The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1]. - sigma_schedule (`str`, *optional*, defaults to `karras`): - Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper - (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential - schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl. - num_train_timesteps (`int`, defaults to 1000): + sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`): + Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper + (https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model: + https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, *optional*, defaults to `1000`): The number of diffusion steps to train the model. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://huggingface.co/papers/2210.02303) paper). - rho (`float`, *optional*, defaults to 7.0): + prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`): + Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and + `"v_prediction"` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper). + rho (`float`, *optional*, defaults to `7.0`): The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1]. - final_sigmas_type (`str`, defaults to `"zero"`): + final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0. 
""" _compatibles = [] @@ -91,12 +90,12 @@ def __init__( sigma_min: float = 0.002, sigma_max: float = 80.0, sigma_data: float = 0.5, - sigma_schedule: str = "karras", + sigma_schedule: Literal["karras", "exponential"] = "karras", num_train_timesteps: int = 1000, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", rho: float = 7.0, - final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" - ): + final_sigmas_type: Literal["zero", "sigma_min"] = "zero", + ) -> None: if sigma_schedule not in ["karras", "exponential"]: raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`") @@ -131,26 +130,41 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def init_noise_sigma(self): - # standard deviation of the initial noise distribution + def init_noise_sigma(self) -> float: + """ + Return the standard deviation of the initial noise distribution. + + Returns: + `float`: + The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`. + """ return (self.config.sigma_max**2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Optional[int]: """ - The index counter for current timestep. It will increase 1 after each scheduler step. + Return the index counter for the current timestep. The index will increase by 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index, or `None` if not yet initialized. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index` + method. + + Returns: + `int` or `None`: + The begin index, or `None` if not yet set. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -160,12 +174,36 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def precondition_inputs(self, sample, sigma): + def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the input sample by scaling it according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor to precondition. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The scaled input sample. + """ c_in = self._get_conditioning_c_in(sigma) scaled_sample = sample * c_in return scaled_sample - def precondition_noise(self, sigma): + def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the noise level by applying a logarithmic transformation. + + Args: + sigma (`float` or `torch.Tensor`): + The sigma (noise level) value to precondition. + + Returns: + `torch.Tensor`: + The preconditioned noise value computed as `0.25 * log(sigma)`. 
+ """ if not isinstance(sigma, torch.Tensor): sigma = torch.tensor([sigma]) @@ -173,7 +211,27 @@ def precondition_noise(self, sigma): return c_noise - def precondition_outputs(self, sample, model_output, sigma): + def precondition_outputs( + self, + sample: torch.Tensor, + model_output: torch.Tensor, + sigma: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Precondition the model outputs according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor. + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The denoised sample computed by combining the skip connection and output scaling. + """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -190,13 +248,13 @@ def precondition_outputs(self, sample, model_output, sigma): def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that + need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample tensor. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: @@ -214,19 +272,19 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T def set_timesteps( self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, sigmas: Optional[Union[torch.Tensor, List[float]]] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - sigmas (`Union[torch.Tensor, List[float]]`, *optional*): + sigmas (`torch.Tensor` or `List[float]`, *optional*): Custom sigmas to use for the denoising process. If not defined, the default behavior when `num_inference_steps` is passed will be used. """ @@ -262,8 +320,27 @@ def set_timesteps( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _compute_karras_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364). + + Args: + ramp (`torch.Tensor`): + A tensor of values in [0, 1] representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. 
If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed Karras sigma schedule. + """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -273,10 +350,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch. sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Implementation closely follows k-diffusion. - + def _compute_exponential_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the exponential sigma schedule. Implementation closely follows k-diffusion: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + + Args: + ramp (`torch.Tensor`): + A tensor of values representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed exponential sigma schedule. """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -342,32 +436,38 @@ def step( generator: Optional[torch.Generator] = None, return_dict: bool = True, pred_original_sample: Optional[torch.Tensor] = None, - ) -> Union[EDMEulerSchedulerOutput, Tuple]: + ) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`float`): + The direct output from the learned diffusion model. + timestep (`float` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - s_churn (`float`): - s_tmin (`float`): - s_tmax (`float`): - s_noise (`float`, defaults to 1.0): + s_churn (`float`, *optional*, defaults to `0.0`): + The amount of stochasticity to add at each step. Higher values add more noise. + s_tmin (`float`, *optional*, defaults to `0.0`): + The minimum sigma threshold below which no noise is added. + s_tmax (`float`, *optional*, defaults to `float("inf")`): + The maximum sigma threshold above which no noise is added. + s_noise (`float`, *optional*, defaults to `1.0`): Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple. + A random number generator for reproducibility. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple. + pred_original_sample (`torch.Tensor`, *optional*): + The predicted denoised sample from a previous step. If provided, skips recomputation. 
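The `_compute_karras_sigmas` helper documented above interpolates linearly in `sigma**(1/rho)` space and then raises the result back to the `rho`-th power. A minimal sketch of the same computation outside the class, assuming plain `torch` and the default config values:

```python
import torch

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
num_inference_steps = 10
ramp = torch.linspace(0, 1, num_inference_steps)

min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
# Karras et al. (2022) schedule: descend from sigma_max to sigma_min.
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# sigmas[0] == sigma_max, sigmas[-1] == sigma_min
```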
Returns: - [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. + [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`: + If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the previous sample tensor and the + second element is the predicted original sample tensor. """ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): @@ -399,7 +499,10 @@ def step( if gamma > 0: noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + model_output.shape, + dtype=model_output.dtype, + device=model_output.device, + generator=generator, ) eps = noise * s_noise sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 @@ -478,9 +581,20 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples - def _get_conditioning_c_in(self, sigma): + def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Compute the input conditioning factor for the EDM formulation. + + Args: + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `float` or `torch.Tensor`: + The input conditioning factor `c_in`. + """ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) return c_in - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 97fd84db5621..11fec60c9c0c 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -51,7 +51,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
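The new `laplace` option for `alpha_transform_type` is added with the same implementation in every scheduler touched below; a standalone sketch of what that branch computes for a normalized timestep `t` in [0, 1]:

```python
import math

def laplace_alpha_bar(t: float) -> float:
    # Laplace-shaped log-SNR, exactly as in the new branch:
    # lmb = -0.5 * sign(0.5 - t) * log(1 - 2*|0.5 - t| + 1e-6)
    lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
    snr = math.exp(lmb)
    return math.sqrt(snr / (1 + snr))

# The returned value decreases monotonically from ~1 at t=0 (high SNR)
# towards ~0 at t=1 (low SNR).
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(t, laplace_alpha_bar(t))
```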
Returns: `torch.Tensor`: @@ -77,6 +77,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index a55a76626cec..8b141325fbd3 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -54,7 +54,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -68,8 +68,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -80,6 +80,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py index 09341c909d2e..2bb6bf35585c 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py @@ -19,6 +19,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -28,6 +29,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class EulerDiscreteSchedulerState: common: CommonSchedulerState @@ -40,9 +44,18 @@ class EulerDiscreteSchedulerState: @classmethod def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + sigmas: jnp.ndarray, ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + return cls( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) @dataclass @@ -99,6 +112,10 @@ def __init__( timestep_spacing: str = "linspace", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState: @@ -146,7 +163,10 @@ def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndar return sample def set_timesteps( - self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, + state: EulerDiscreteSchedulerState, + num_inference_steps: int, + shape: Tuple = (), ) -> EulerDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -159,7 +179,12 @@ def set_timesteps( """ if self.config.timestep_spacing == "linspace": - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + timesteps = jnp.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=self.dtype, + ) elif self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // num_inference_steps timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float) diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 9fd61d9e18d1..378a62ca8aee 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -171,8 +171,8 @@ def set_shift(self, shift: float): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - noise: Optional[torch.FloatTensor] = None, + timestep: torch.FloatTensor, + noise: torch.FloatTensor, ) -> torch.FloatTensor: """ Forward process in flow-matching @@ -180,8 +180,10 @@ def scale_noise( Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`): + The noise tensor. Returns: `torch.FloatTensor`: diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py index 6febee444c5a..6b85194f8b5e 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py @@ -110,8 +110,8 @@ def set_begin_index(self, begin_index: int = 0): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - noise: Optional[torch.FloatTensor] = None, + timestep: torch.FloatTensor, + noise: torch.FloatTensor, ) -> torch.FloatTensor: """ Forward process in flow-matching @@ -119,8 +119,10 @@ def scale_noise( Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`): + The noise tensor. 
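The flow-matching schedulers above share the same forward (noising) process inside `scale_noise`; a minimal sketch of that interpolation, assuming `sigma` is the flow-matching noise level in [0, 1] taken from `scheduler.sigmas`:

```python
import torch

def scale_noise_sketch(sample: torch.Tensor, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # Linear interpolation between data and noise: x_t = sigma * noise + (1 - sigma) * x_0.
    # With the updated signatures, `noise` is a required argument rather than optional.
    return sigma * noise + (1.0 - sigma) * sample

x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
x_t = scale_noise_sketch(x0, torch.tensor(0.7), noise)
```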
Returns: `torch.FloatTensor`: @@ -130,6 +132,7 @@ def scale_noise( self._init_step_index(timestep) sigma = self.sigmas[self.step_index] + sample = sigma * noise + (1.0 - sigma) * sample return sample diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py index 25186d1fe969..8ef0e2ec8175 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_lcm.py +++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py @@ -192,8 +192,8 @@ def set_scale_factors(self, scale_factors: list, upscale_mode): def scale_noise( self, sample: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - noise: Optional[torch.FloatTensor] = None, + timestep: torch.FloatTensor, + noise: torch.FloatTensor, ) -> torch.FloatTensor: """ Forward process in flow-matching @@ -201,8 +201,10 @@ def scale_noise( Args: sample (`torch.FloatTensor`): The input sample. - timestep (`int`, *optional*): + timestep (`torch.FloatTensor`): The current timestep in the diffusion chain. + noise (`torch.FloatTensor`): + The noise tensor. Returns: `torch.FloatTensor`: diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index b113f9b49832..0c5e28ad067d 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -51,7 +51,7 @@ class HeunDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -77,6 +77,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index da40bed635e1..ee49ae67b9cb 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -52,7 +52,7 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. 
max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -78,6 +78,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 6dc08d4d0a86..6effb3699b9a 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -51,7 +51,7 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -65,8 +65,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -77,6 +77,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index bacfbd61006d..3f43a5fa9963 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -22,10 +22,13 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, logging from .scheduling_utils_flax import FlaxSchedulerMixin +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class KarrasVeSchedulerState: # setable values @@ -102,7 +105,10 @@ def __init__( s_min: float = 0.05, s_max: float = 50, ): - pass + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." 
+ ) def create_state(self): return KarrasVeSchedulerState.create() diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 0527f3533851..ada8806e8c73 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -53,7 +53,7 @@ class LCMSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -67,8 +67,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -79,6 +79,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -715,7 +722,7 @@ def previous_timestep(self, timestep): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 276af6eeacb7..a1f9d27fd938 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -49,7 +49,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -63,8 +63,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -75,6 +75,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 3fd4dc8a5d61..4edb091348c8 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,6 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -29,6 +30,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class LMSDiscreteSchedulerState: common: CommonSchedulerState @@ -44,9 +48,18 @@ class LMSDiscreteSchedulerState: @classmethod def create( - cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray + cls, + common: CommonSchedulerState, + init_noise_sigma: jnp.ndarray, + timesteps: jnp.ndarray, + sigmas: jnp.ndarray, ): - return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas) + return cls( + common=common, + init_noise_sigma=init_noise_sigma, + timesteps=timesteps, + sigmas=sigmas, + ) @dataclass @@ -101,6 +114,10 @@ def __init__( prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype def create_state(self, common: Optional[CommonSchedulerState] = None) -> LMSDiscreteSchedulerState: @@ -165,7 +182,10 @@ def lms_derivative(tau): return integrated_coeff def set_timesteps( - self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = () + self, + state: LMSDiscreteSchedulerState, + num_inference_steps: int, + shape: Tuple = (), ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -177,7 +197,12 @@ def set_timesteps( the number of diffusion steps used when generating samples with a pre-trained model. """ - timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype) + timesteps = jnp.linspace( + self.config.num_train_timesteps - 1, + 0, + num_inference_steps, + dtype=self.dtype, + ) low_idx = jnp.floor(timesteps).astype(jnp.int32) high_idx = jnp.ceil(timesteps).astype(jnp.int32) diff --git a/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py b/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py new file mode 100644 index 000000000000..6710254f4445 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py @@ -0,0 +1,386 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LTXEulerAncestralRFScheduler + +This scheduler implements a K-diffusion style Euler-Ancestral sampler specialized for flow / CONST parameterization, +closely mirroring ComfyUI's `sample_euler_ancestral_RF` implementation used for LTX-Video. + +Reference implementation (ComfyUI): + comfy.k_diffusion.sampling.sample_euler_ancestral_RF +""" + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from ..utils.torch_utils import randn_tensor +from .scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LTXEulerAncestralRFSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor`): + Updated sample for the next step in the denoising process. + """ + + prev_sample: torch.FloatTensor + + +class LTXEulerAncestralRFScheduler(SchedulerMixin, ConfigMixin): + """ + Euler-Ancestral scheduler for LTX-Video (RF / CONST parametrization). + + This scheduler is intended for models where the network is trained with a CONST-like parameterization (as in LTXV / + FLUX). It approximates ComfyUI's `sample_euler_ancestral_RF` sampler and is useful when reproducing ComfyUI + workflows inside diffusers. + + The scheduler can either: + - reuse the [`FlowMatchEulerDiscreteScheduler`] sigma / timestep logic when only `num_inference_steps` is provided + (default diffusers-style usage), or + - follow an explicit ComfyUI-style sigma schedule when `sigmas` (or `timesteps`) are passed to [`set_timesteps`]. + + Args: + num_train_timesteps (`int`, defaults to 1000): + Included for config compatibility; not used to build the schedule. + eta (`float`, defaults to 1.0): + Stochasticity parameter. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0` matches ComfyUI's + default RF behavior. + s_noise (`float`, defaults to 1.0): + Global scaling factor for the stochastic noise term. + """ + + # Allow config migration from the flow-match scheduler and back. + _compatibles = ["FlowMatchEulerDiscreteScheduler"] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + eta: float = 1.0, + s_noise: float = 1.0, + ): + # Note: num_train_timesteps is kept only for config compatibility. + self.num_inference_steps: Optional[int] = None + self.sigmas: Optional[torch.Tensor] = None + self.timesteps: Optional[torch.Tensor] = None + self._step_index: Optional[int] = None + self._begin_index: Optional[int] = None + + @property + def step_index(self) -> Optional[int]: + return self._step_index + + @property + def begin_index(self) -> Optional[int]: + """ + The index for the first timestep. It can be set from a pipeline with `set_begin_index` to support + image-to-image like workflows that start denoising part-way through the schedule. 
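A rough usage sketch of the new scheduler in its default, diffusers-style mode. The import path mirrors the module added in this patch; whether the class is also re-exported from the top-level `diffusers` namespace is not shown here, and the latent shape and zero model output are placeholders only:

```python
import torch
from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler

scheduler = LTXEulerAncestralRFScheduler(eta=1.0, s_noise=1.0)

# Default path: sigma construction is delegated to the FlowMatchEulerDiscreteScheduler logic.
scheduler.set_timesteps(num_inference_steps=8, device="cpu")

sample = torch.randn(1, 128, 16, 16)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # stand-in for the denoiser output
    sample = scheduler.step(model_output, t, sample).prev_sample
```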
+ """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Included for API compatibility; not strictly needed here but kept to allow pipelines that call + `set_begin_index`. + """ + self._begin_index = begin_index + + def index_for_timestep( + self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Map a (continuous) `timestep` value to an index into `self.timesteps`. + + This follows the convention used in other discrete schedulers: if the same timestep value appears multiple + times in the schedule (which can happen when starting in the middle of the schedule), the *second* occurrence + is used for the first `step` call so that no sigma is accidentally skipped. + """ + if schedule_timesteps is None: + if self.timesteps is None: + raise ValueError("Timesteps have not been set. Call `set_timesteps` first.") + schedule_timesteps = self.timesteps + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(schedule_timesteps.device) + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + if len(indices) == 0: + raise ValueError( + "Passed `timestep` is not in `self.timesteps`. Make sure to use values from `scheduler.timesteps`." + ) + + return indices[pos].item() + + def _init_step_index(self, timestep: Union[float, torch.Tensor]): + """ + Initialize the internal step index based on a given timestep. + """ + if self.timesteps is None: + raise ValueError("Timesteps have not been set. Call `set_timesteps` first.") + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def set_timesteps( + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device, None] = None, + sigmas: Optional[Union[List[float], torch.Tensor]] = None, + timesteps: Optional[Union[List[float], torch.Tensor]] = None, + mu: Optional[float] = None, + **kwargs, + ): + """ + Set the sigma / timestep schedule for sampling. + + When `sigmas` or `timesteps` are provided explicitly, they are used as the RF sigma schedule (ComfyUI-style) + and are expected to include the terminal 0.0. When both are `None`, the scheduler reuses the + [`FlowMatchEulerDiscreteScheduler`] logic to generate sigmas from `num_inference_steps` and the stored config + (including any resolution-dependent shifting, Karras/beta schedules, etc.). + + Args: + num_inference_steps (`int`, *optional*): + Number of denoising steps. If provided together with explicit `sigmas`/`timesteps`, they are expected + to be consistent and are otherwise ignored with a warning. + device (`str` or `torch.device`, *optional*): + Device to move the internal tensors to. + sigmas (`List[float]` or `torch.Tensor`, *optional*): + Explicit sigma schedule, e.g. `[1.0, 0.99, ..., 0.0]`. + timesteps (`List[float]` or `torch.Tensor`, *optional*): + Optional alias for `sigmas`. If `sigmas` is None and `timesteps` is provided, timesteps are treated as + sigmas. 
+ mu (`float`, *optional*): + Optional shift parameter used when delegating to [`FlowMatchEulerDiscreteScheduler.set_timesteps`] and + `config.use_dynamic_shifting` is `True`. + """ + # 1. Auto-generate schedule (FlowMatch-style) when no explicit sigmas/timesteps are given + if sigmas is None and timesteps is None: + if num_inference_steps is None: + raise ValueError( + "LTXEulerAncestralRFScheduler.set_timesteps requires either explicit `sigmas`/`timesteps` " + "or a `num_inference_steps` value." + ) + + # We reuse FlowMatchEulerDiscreteScheduler to construct a sigma schedule that is + # consistent with the original LTX training setup (including optional time shifting, + # Karras / exponential / beta schedules, etc.). + from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + + base_scheduler = FlowMatchEulerDiscreteScheduler.from_config(self.config) + base_scheduler.set_timesteps( + num_inference_steps=num_inference_steps, + device=device, + sigmas=None, + mu=mu, + timesteps=None, + ) + + self.num_inference_steps = base_scheduler.num_inference_steps + # Keep sigmas / timesteps on the requested device so step() can operate on-device without + # extra transfers. + self.sigmas = base_scheduler.sigmas.to(device=device) + self.timesteps = base_scheduler.timesteps.to(device=device) + self._step_index = None + self._begin_index = None + return + + # 2. Explicit sigma schedule (ComfyUI-style path) + if sigmas is None: + # `timesteps` is treated as sigmas in RF / flow-matching setups. + sigmas = timesteps + + if isinstance(sigmas, list): + sigmas_tensor = torch.tensor(sigmas, dtype=torch.float32) + elif isinstance(sigmas, torch.Tensor): + sigmas_tensor = sigmas.to(dtype=torch.float32) + else: + raise TypeError(f"`sigmas` must be a list or torch.Tensor, got {type(sigmas)}.") + + if sigmas_tensor.ndim != 1: + raise ValueError(f"`sigmas` must be a 1D tensor, got shape {tuple(sigmas_tensor.shape)}.") + + if sigmas_tensor[-1].abs().item() > 1e-6: + logger.warning( + "The last sigma in the schedule is not zero (%.6f). " + "For best compatibility with ComfyUI's RF sampler, the terminal sigma " + "should be 0.0.", + sigmas_tensor[-1].item(), + ) + + # Move to device once, then derive timesteps. + if device is not None: + sigmas_tensor = sigmas_tensor.to(device) + + # Internal sigma schedule stays in [0, 1] (as provided). + self.sigmas = sigmas_tensor + # Timesteps are scaled to match the training setup of LTX (FlowMatch-style), + # where the network expects timesteps on [0, num_train_timesteps]. + # This keeps the transformer conditioning in the expected range while the RF + # scheduler still operates on the raw sigma values. + num_train = float(getattr(self.config, "num_train_timesteps", 1000)) + self.timesteps = sigmas_tensor * num_train + + if num_inference_steps is not None and num_inference_steps != len(sigmas) - 1: + logger.warning( + "Provided `num_inference_steps=%d` does not match `len(sigmas)-1=%d`. " + "Overriding `num_inference_steps` with `len(sigmas)-1`.", + num_inference_steps, + len(sigmas) - 1, + ) + + self.num_inference_steps = len(sigmas) - 1 + self._step_index = None + self._begin_index = None + + def _sigma_broadcast(self, sigma: torch.Tensor, sample: torch.Tensor) -> torch.Tensor: + """ + Helper to broadcast a scalar sigma to the shape of `sample`. 
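For the ComfyUI-style path described above, `set_timesteps` can instead be driven with an explicit RF sigma schedule ending in 0.0; a small sketch with illustrative sigma values, not taken from any particular workflow:

```python
from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler

scheduler = LTXEulerAncestralRFScheduler()

# Explicit RF sigma schedule in [0, 1]; the terminal 0.0 is expected.
sigmas = [1.0, 0.85, 0.65, 0.45, 0.25, 0.1, 0.0]
scheduler.set_timesteps(sigmas=sigmas, device="cpu")

print(scheduler.num_inference_steps)  # len(sigmas) - 1 == 6
print(scheduler.timesteps)            # sigmas scaled by num_train_timesteps (1000 by default)
```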
+ """ + while sigma.ndim < sample.ndim: + sigma = sigma.view(*sigma.shape, 1) + return sigma + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.Tensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[LTXEulerAncestralRFSchedulerOutput, Tuple[torch.FloatTensor]]: + """ + Perform a single Euler-Ancestral RF update step. + + Args: + model_output (`torch.FloatTensor`): + Raw model output at the current step. Interpreted under the CONST parametrization as `v_t`, with + denoised state reconstructed as `x0 = x_t - sigma_t * v_t`. + timestep (`float` or `torch.Tensor`): + The current sigma value (must match one entry in `self.timesteps`). + sample (`torch.FloatTensor`): + Current latent sample `x_t`. + generator (`torch.Generator`, *optional*): + Optional generator for reproducible noise. + return_dict (`bool`): + If `True`, return a `LTXEulerAncestralRFSchedulerOutput`; otherwise return a tuple where the first + element is the updated sample. + """ + + if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `LTXEulerAncestralRFScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` values as `timestep`." + ), + ) + + if self.sigmas is None or self.timesteps is None: + raise ValueError("Scheduler has not been initialized. Call `set_timesteps` before `step`.") + + if self._step_index is None: + self._init_step_index(timestep) + + i = self._step_index + if i >= len(self.sigmas) - 1: + # Already at the end; simply return the current sample. + prev_sample = sample + else: + # Work in float32 for numerical stability + sample_f = sample.to(torch.float32) + model_output_f = model_output.to(torch.float32) + + sigma = self.sigmas[i] + sigma_next = self.sigmas[i + 1] + + sigma_b = self._sigma_broadcast(sigma.view(1), sample_f) + sigma_next_b = self._sigma_broadcast(sigma_next.view(1), sample_f) + + # Approximate denoised x0 under CONST parametrization: + # x0 = x_t - sigma_t * v_t + denoised = sample_f - sigma_b * model_output_f + + if sigma_next.abs().item() < 1e-8: + # Final denoising step + x = denoised + else: + eta = float(self.config.eta) + s_noise = float(self.config.s_noise) + + # Downstep computation (ComfyUI RF variant) + downstep_ratio = 1.0 + (sigma_next / sigma - 1.0) * eta + sigma_down = sigma_next * downstep_ratio + + alpha_ip1 = 1.0 - sigma_next + alpha_down = 1.0 - sigma_down + + # Deterministic part (Euler step in (x, x0)-space) + sigma_down_b = self._sigma_broadcast(sigma_down.view(1), sample_f) + alpha_ip1_b = self._sigma_broadcast(alpha_ip1.view(1), sample_f) + alpha_down_b = self._sigma_broadcast(alpha_down.view(1), sample_f) + + sigma_ratio = sigma_down_b / sigma_b + x = sigma_ratio * sample_f + (1.0 - sigma_ratio) * denoised + + # Stochastic ancestral noise + if eta > 0.0 and s_noise > 0.0: + renoise_coeff = ( + (sigma_next_b**2 - sigma_down_b**2 * alpha_ip1_b**2 / (alpha_down_b**2 + 1e-12)) + .clamp(min=0.0) + .sqrt() + ) + + noise = randn_tensor( + sample_f.shape, generator=generator, device=sample_f.device, dtype=sample_f.dtype + ) + x = (alpha_ip1_b / (alpha_down_b + 1e-12)) * x + noise * renoise_coeff * s_noise + + prev_sample = x.to(sample.dtype) + + # Advance internal step index + self._step_index = min(self._step_index + 1, len(self.sigmas) - 1) + + if not return_dict: + return (prev_sample,) + + return 
LTXEulerAncestralRFSchedulerOutput(prev_sample=prev_sample) + + def __len__(self) -> int: + # For compatibility with other schedulers; used e.g. in some training + # utilities to infer the maximum number of training timesteps. + return int(getattr(self.config, "num_train_timesteps", 1000)) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 651532b06ddb..0820f5baa871 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -28,7 +28,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -42,8 +42,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -54,6 +54,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 44bafccd5520..bbef4649ecb5 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils_flax import ( CommonSchedulerState, FlaxKarrasDiffusionSchedulers, @@ -31,6 +32,9 @@ ) +logger = logging.get_logger(__name__) + + @flax.struct.dataclass class PNDMSchedulerState: common: CommonSchedulerState @@ -131,6 +135,10 @@ def __init__( prediction_type: str = "epsilon", dtype: jnp.dtype = jnp.float32, ): + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) self.dtype = dtype # For now we only support F-PNDM, i.e. 
the runge-kutta method @@ -190,7 +198,10 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha else: prk_timesteps = _timesteps[-self.pndm_order :].repeat(2) + jnp.tile( - jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2], dtype=jnp.int32), + jnp.array( + [0, self.config.num_train_timesteps // num_inference_steps // 2], + dtype=jnp.int32, + ), self.pndm_order, ) @@ -218,7 +229,10 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha ) def scale_model_input( - self, state: PNDMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None + self, + state: PNDMSchedulerState, + sample: jnp.ndarray, + timestep: Optional[int] = None, ) -> jnp.ndarray: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -320,7 +334,9 @@ def step_prk( ) diff_to_prev = jnp.where( - state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2 + state.counter % 2, + 0, + self.config.num_train_timesteps // state.num_inference_steps // 2, ) prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] @@ -401,7 +417,9 @@ def step_plms( prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) timestep = jnp.where( - state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep + state.counter == 1, + timestep + self.config.num_train_timesteps // state.num_inference_steps, + timestep, ) # Reference: @@ -466,7 +484,9 @@ def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_tim # prev_sample -> x_(t−δ) alpha_prod_t = state.common.alphas_cumprod[timestep] alpha_prod_t_prev = jnp.where( - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod + prev_timestep >= 0, + state.common.alphas_cumprod[prev_timestep], + state.final_alpha_cumprod, ) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index a2eaf8eb3abd..bec4a1bdf652 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -47,7 +47,7 @@ class RePaintSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -61,8 +61,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -73,6 +73,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 5783e20de69d..565fae1c0d76 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -35,7 +35,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -49,8 +49,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -61,6 +61,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 09cd081462b3..f4fe6d8f6bbf 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -23,7 +23,15 @@ from jax import random from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from ..utils import logging +from .scheduling_utils_flax import ( + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) + + +logger = logging.get_logger(__name__) @flax.struct.dataclass @@ -95,7 +103,10 @@ def __init__( sampling_eps: float = 1e-5, correct_steps: int = 1, ): - pass + logger.warning( + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " + "recommend migrating to PyTorch classes or pinning your version of Diffusers." + ) def create_state(self): state = ScoreSdeVeSchedulerState.create() @@ -108,7 +119,11 @@ def create_state(self): ) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None + self, + state: ScoreSdeVeSchedulerState, + num_inference_steps: int, + shape: Tuple = (), + sampling_eps: float = None, ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. 
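The deprecation warnings added to the Flax schedulers in this patch are emitted through the standard diffusers logger, so downstream code can control their visibility while migrating; a small sketch, assuming the usual `diffusers.utils.logging` verbosity helpers:

```python
from diffusers.utils import logging

logging.set_verbosity_error()    # hide the Flax deprecation warnings
# ... instantiate Flax schedulers here ...
logging.set_verbosity_warning()  # restore the default verbosity afterwards
```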
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 7b4840ffdb19..a1303436cd8d 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -52,7 +52,7 @@ class TCDSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -66,8 +66,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -78,6 +78,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -770,7 +777,7 @@ def previous_timestep(self, timestep): The current timestep. Returns: - `int`: + `int` or `torch.Tensor`: The previous timestep. """ if self.custom_timesteps or self.num_inference_steps: diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index 5a978dec649b..14b09277da04 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -48,7 +48,7 @@ class UnCLIPSchedulerOutput(BaseOutput): def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -62,8 +62,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. 
Returns: `torch.Tensor`: @@ -74,6 +74,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 6800c1220177..d8e24d196418 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -34,7 +34,7 @@ def betas_for_alpha_bar( num_diffusion_timesteps: int, max_beta: float = 0.999, - alpha_transform_type: Literal["cosine", "exp"] = "cosine", + alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine", ) -> torch.Tensor: """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of @@ -48,8 +48,8 @@ def betas_for_alpha_bar( The number of betas to produce. max_beta (`float`, defaults to `0.999`): The maximum beta to use; use values lower than 1 to avoid numerical instability. - alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`): - The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`. + alpha_transform_type (`str`, defaults to `"cosine"`): + The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`. Returns: `torch.Tensor`: @@ -60,6 +60,13 @@ def betas_for_alpha_bar( def alpha_bar_fn(t): return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == "laplace": + + def alpha_bar_fn(t): + lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6) + snr = math.exp(lmb) + return math.sqrt(snr / (1 + snr)) + elif alpha_transform_type == "exp": def alpha_bar_fn(t): @@ -77,7 +84,7 @@ def alpha_bar_fn(t): # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): +def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor: """ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) @@ -127,19 +134,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. - solver_order (`int`, default `2`): + solver_order (`int`, defaults to `2`): The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. 
- prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://huggingface.co/papers/2210.02303) paper). + `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`. thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -149,7 +156,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. predict_x0 (`bool`, defaults to `True`): Whether to use the updating algorithm on the predicted x0. - solver_type (`str`, default `bh2`): + solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`): Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` otherwise. lower_order_final (`bool`, default `True`): @@ -171,12 +178,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. use_flow_sigmas (`bool`, *optional*, defaults to `False`): Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. - timestep_spacing (`str`, defaults to `"linspace"`): + timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. steps_offset (`int`, defaults to 0): An offset added to the inference steps, as required by some model families. - final_sigmas_type (`str`, defaults to `"zero"`): + final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
rescale_betas_zero_snr (`bool`, defaults to `False`): @@ -194,30 +201,33 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, solver_order: int = 2, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, predict_x0: bool = True, - solver_type: str = "bh2", + solver_type: Literal["bh1", "bh2"] = "bh2", lower_order_final: bool = True, disable_corrector: List[int] = [], - solver_p: SchedulerMixin = None, + solver_p: Optional[SchedulerMixin] = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, use_flow_sigmas: Optional[bool] = False, flow_shift: Optional[float] = 1.0, - timestep_spacing: str = "linspace", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace", steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero", rescale_betas_zero_snr: bool = False, use_dynamic_shifting: bool = False, - time_shift_type: str = "exponential", - ): + time_shift_type: Literal["exponential"] = "exponential", + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + shift_terminal: Optional[float] = None, + ) -> None: if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: @@ -236,6 +246,8 @@ def __init__( self.betas = betas_for_alpha_bar(num_train_timesteps) else: raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + if shift_terminal is not None and not use_flow_sigmas: + raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.") if rescale_betas_zero_snr: self.betas = rescale_zero_terminal_snr(self.betas) @@ -279,21 +291,21 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def step_index(self): + def step_index(self) -> Optional[int]: """ The index counter for current timestep. It will increase 1 after each scheduler step. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ The index for the first timestep. It should be set from pipeline with `set_begin_index` method. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. 
@@ -304,7 +316,11 @@ def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index def set_timesteps( - self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None + self, + num_inference_steps: Optional[int] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -314,11 +330,24 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Optional mu parameter for dynamic shifting when using exponential time shift type. """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None: + if not self.config.use_flow_sigmas: + raise ValueError( + "Passing `sigmas` is only supported when `use_flow_sigmas=True`. " + "Please set `use_flow_sigmas=True` during scheduler initialization." + ) + num_inference_steps = len(sigmas) + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 - if mu is not None: - assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential" - self.config.flow_shift = np.exp(mu) if self.config.timestep_spacing == "linspace": timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) @@ -343,12 +372,18 @@ def set_timesteps( f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." 
) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) if self.config.use_karras_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.use_flow_sigmas: + sigmas = sigmas / (sigmas + 1) + timesteps = (sigmas * self.config.num_train_timesteps).copy() + else: + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + if self.config.final_sigmas_type == "sigma_min": sigma_last = sigmas[-1] elif self.config.final_sigmas_type == "zero": @@ -359,6 +394,8 @@ def set_timesteps( ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_exponential_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -373,6 +410,8 @@ def set_timesteps( ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_beta_sigmas: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) sigmas = np.flip(sigmas).copy() sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) @@ -387,9 +426,18 @@ def set_timesteps( ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) elif self.config.use_flow_sigmas: - alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) - sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() + if sigmas is None: + sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1] + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas) + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + eps = 1e-6 + if np.fabs(sigmas[0] - 1) < eps: + # to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update + sigmas[0] -= eps timesteps = (sigmas * self.config.num_train_timesteps).copy() if self.config.final_sigmas_type == "sigma_min": sigma_last = sigmas[-1] @@ -401,6 +449,8 @@ def set_timesteps( ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) else: + if sigmas is None: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -430,6 +480,43 @@ def set_timesteps( self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, 
sigma, t) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: """ @@ -475,7 +562,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: return sample # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): + def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray: """ Convert sigma values to corresponding timestep values through interpolation. @@ -512,7 +599,7 @@ def _sigma_to_t(self, sigma, log_sigmas): return t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t - def _sigma_to_alpha_sigma_t(self, sigma): + def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Convert sigma values to alpha_t and sigma_t values. @@ -534,7 +621,7 @@ def _sigma_to_alpha_sigma_t(self, sigma): return alpha_t, sigma_t # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: """ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364). @@ -1030,7 +1117,7 @@ def index_for_timestep( return step_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): + def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None: """ Initialize the step_index counter for the scheduler. @@ -1060,11 +1147,11 @@ def step( Args: model_output (`torch.Tensor`): The direct output from learned diffusion model. - timestep (`int`): + timestep (`int` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. 
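A hedged usage sketch for the extended UniPCMultistepScheduler.set_timesteps shown above: custom sigmas are only accepted together with use_flow_sigmas=True (the number of inference steps is then inferred from their length), and mu becomes mandatory once use_dynamic_shifting=True. The concrete sigma values and mu are placeholders, and a NumPy array is used for the custom sigmas so the element-wise flow shift applies cleanly; this is an untested illustration, not reference usage.

import numpy as np
from diffusers import UniPCMultistepScheduler

# Flow-matching style schedule with user-provided sigmas
scheduler = UniPCMultistepScheduler(use_flow_sigmas=True, flow_shift=3.0)
custom_sigmas = np.linspace(1.0, 1.0 / 1000, 9)[:-1]  # 8 steps, illustrative values in (0, 1]
scheduler.set_timesteps(sigmas=custom_sigmas, device="cpu")
print(scheduler.timesteps)  # derived as sigmas * num_train_timesteps after the flow shift

# Dynamic shifting requires `mu`; omitting it raises a ValueError
shifted = UniPCMultistepScheduler(use_flow_sigmas=True, use_dynamic_shifting=True)
shifted.set_timesteps(num_inference_steps=8, mu=1.1, device="cpu")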
- return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: @@ -1192,5 +1279,5 @@ def add_noise( noisy_samples = alpha_t * original_samples + sigma_t * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 7a98fa3da14a..3e9968d47fdd 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -6,11 +6,18 @@ import re import warnings from contextlib import contextmanager -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import numpy as np import torch + +if getattr(torch, "distributed", None) is not None: + from torch.distributed.fsdp import CPUOffload, ShardingStrategy + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + from .models import UNet2DConditionModel from .pipelines import DiffusionPipeline from .schedulers import SchedulerMixin @@ -18,6 +25,7 @@ convert_state_dict_to_diffusers, convert_state_dict_to_peft, deprecate, + is_accelerate_available, is_peft_available, is_torch_npu_available, is_torchvision_available, @@ -31,6 +39,9 @@ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed +if is_accelerate_available(): + from accelerate.logging import get_logger + if is_peft_available(): from peft import set_peft_model_state_dict @@ -394,6 +405,86 @@ def find_nearest_bucket(h, w, bucket_options): return best_bucket_idx +def _to_cpu_contiguous(state_dicts) -> dict: + return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()} + + +def get_fsdp_kwargs_from_accelerator(accelerator) -> dict: + """ + Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs. + """ + + kwargs = {} + fsdp_state = getattr(accelerator.state, "fsdp_plugin", None) + + if fsdp_state is None: + raise ValueError("Accelerate isn't configured to handle FSDP. Please update your installation.") + + fsdp_plugin = accelerator.state.fsdp_plugin + + if fsdp_plugin is None: + # FSDP not enabled in Accelerator + kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD + else: + # FSDP is enabled → use plugin's strategy, or default if None + kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD + + return kwargs + + +def wrap_with_fsdp( + model: torch.nn.Module, + device: Union[str, torch.device], + offload: bool = True, + use_orig_params: bool = True, + limit_all_gathers: bool = True, + fsdp_kwargs: Optional[Dict[str, Any]] = None, + transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None, +) -> FSDP: + """ + Wrap a model with FSDP using common defaults and optional transformer auto-wrapping. + + Args: + model: Model to wrap + device: Target device (e.g., accelerator.device) + offload: Whether to enable CPU parameter offloading + use_orig_params: Whether to use original parameters + limit_all_gathers: Whether to limit all gathers + fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) 
— usually from Accelerate config + transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs) + + Returns: + FSDP-wrapped model + """ + + logger = get_logger(__name__) + + if transformer_layer_cls is None: + # Set the default layers if transformer_layer_cls is not provided + transformer_layer_cls = type(model.model.language_model.layers[0]) + logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}") + + # Add auto-wrap policy if transformer layers specified + auto_wrap_policy = partial( + transformer_auto_wrap_policy, + transformer_layer_cls={transformer_layer_cls}, + ) + + config = { + "device_id": device, + "cpu_offload": CPUOffload(offload_params=offload) if offload else None, + "use_orig_params": use_orig_params, + "limit_all_gathers": limit_all_gathers, + "auto_wrap_policy": auto_wrap_policy, + } + + if fsdp_kwargs: + config.update(fsdp_kwargs) + + fsdp_model = FSDP(model, **config) + return fsdp_model + + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6884d3be9292..3f736e2ee39b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -23,6 +23,7 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS, DEPRECATED_REVISION_ARGS, DIFFUSERS_DYNAMIC_MODULE_NAME, + DIFFUSERS_LOAD_ID_FIELDS, FLAX_WEIGHTS_NAME, GGUF_FILE_EXTENSION, HF_ENABLE_PARALLEL_LOADING, @@ -66,6 +67,7 @@ is_accelerate_version, is_aiter_available, is_aiter_version, + is_av_available, is_better_profanity_available, is_bitsandbytes_available, is_bitsandbytes_version, @@ -143,6 +145,7 @@ from .remote_utils import remote_decode from .state_dict_utils import ( convert_all_state_dict_to_peft, + convert_sai_sd_control_lora_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, convert_state_dict_to_peft, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index c46fa4363483..4f94df656a65 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -73,3 +73,11 @@ ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/" ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/" + + +DIFFUSERS_LOAD_ID_FIELDS = [ + "pretrained_model_name_or_path", + "subfolder", + "variant", + "revision", +] diff --git a/src/diffusers/utils/distributed_utils.py b/src/diffusers/utils/distributed_utils.py new file mode 100644 index 000000000000..239b7b26200d --- /dev/null +++ b/src/diffusers/utils/distributed_utils.py @@ -0,0 +1,36 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
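A hedged sketch of how the two new helpers in training_utils.py (get_fsdp_kwargs_from_accelerator and wrap_with_fsdp) might be combined in a training script. It assumes an FSDP-enabled Accelerate config and a process group set up by accelerate launch; TinyBlock is a placeholder module, and a single layer class is passed for transformer_layer_cls because the helper wraps it in a set internally before handing it to transformer_auto_wrap_policy.

import torch
from accelerate import Accelerator
from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp

class TinyBlock(torch.nn.Module):  # placeholder "transformer layer" for illustration
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

accelerator = Accelerator()  # assumes FSDP is enabled via `accelerate config`
model = torch.nn.Sequential(TinyBlock(), TinyBlock())

# Reuse the sharding strategy from the Accelerate FSDP plugin (falls back to FULL_SHARD)
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)

# Shard the model, offload parameters to CPU, and auto-wrap at the TinyBlock level
model = wrap_with_fsdp(
    model,
    device=accelerator.device,
    offload=True,
    fsdp_kwargs=fsdp_kwargs,
    transformer_layer_cls=TinyBlock,
)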
+ + +try: + import torch +except ImportError: + torch = None + + +def is_torch_dist_rank_zero() -> bool: + if torch is None: + return True + + dist_module = getattr(torch, "distributed", None) + if dist_module is None or not dist_module.is_available(): + return True + + if not dist_module.is_initialized(): + return True + + try: + return dist_module.get_rank() == 0 + except (RuntimeError, ValueError): + return True diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6be7618fcd5e..6c436161c5a7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -227,6 +227,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class MagCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -257,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class TaylorSeerCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) @@ -269,10 +299,18 @@ def apply_layer_skip(*args, **kwargs): requires_backends(apply_layer_skip, ["torch"]) +def apply_mag_cache(*args, **kwargs): + requires_backends(apply_mag_cache, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) +def apply_taylorseer_cache(*args, **kwargs): + requires_backends(apply_taylorseer_cache, ["torch"]) + + class AllegroTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -483,6 +521,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLLTX2Audio(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class AutoencoderKLLTX2Video(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -933,6 +1001,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class GlmImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class 
HiDreamImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1113,6 +1196,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LongCatImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class LTX2VideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1743,6 +1856,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ZImageControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ZImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -2585,6 +2713,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTXEulerAncestralRFScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PNDMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b62bfa734e3b..a23f852616c0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,81 @@ from ..utils import DummyObject, requires_backends +class Flux2AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinBaseAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def 
from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2KleinModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Flux2ModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -137,6 +212,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class QwenImageLayeredAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class QwenImageLayeredModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -227,6 +332,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ZImageAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ZImageModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AllegroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -497,6 +632,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class BriaFiboEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, 
**kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class BriaFiboPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -542,6 +692,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ChromaInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ChromaPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -707,6 +872,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Cosmos2_5_PredictBasePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Cosmos2TextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -827,6 +1007,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1037,6 +1232,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class GlmImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HiDreamImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1772,6 +1982,81 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LongCatImageEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LongCatImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", 
"transformers"]) + + +class LTX2ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2LatentUpsamplePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTX2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXConditionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -1787,6 +2072,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXI2VLongMultiPromptPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LTXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2207,6 +2507,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class QwenImageLayeredPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -3752,6 +4067,66 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ZImageControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ZImageControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", 
"transformers"]) + + +class ZImageImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ZImageOmniPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ZImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index d0b05c7d9541..58695bae1e9d 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -107,6 +107,7 @@ def load_or_create_model_card( license: Optional[str] = None, widget: Optional[List[dict]] = None, inference: Optional[bool] = None, + is_modular: bool = False, ) -> ModelCard: """ Loads or creates a model card. @@ -131,6 +132,8 @@ def load_or_create_model_card( widget (`List[dict]`, *optional*): Widget to accompany a gallery template. inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using `load_or_create_model_card` from a training script. + is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline. + When True, uses model_description as-is without additional template formatting. """ if not is_jinja_available(): raise ValueError( @@ -159,10 +162,14 @@ def load_or_create_model_card( ) else: card_data = ModelCardData() - component = "pipeline" if is_pipeline else "model" - if model_description is None: - model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." - model_card = ModelCard.from_template(card_data, model_description=model_description) + if is_modular and model_description is not None: + model_card = ModelCard(model_description) + model_card.data = card_data + else: + component = "pipeline" if is_pipeline else "model" + if model_description is None: + model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated." 
+ model_card = ModelCard.from_template(card_data, model_description=model_description) return model_card diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 57b0a337922a..e35fc9697076 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -227,9 +227,10 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") -_aiter_available, _aiter_version = _is_package_available("aiter") +_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True) _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) +_av_available, _av_version = _is_package_available("av") def is_torch_available(): @@ -420,6 +421,10 @@ def is_kornia_available(): return _kornia_available +def is_av_available(): + return _av_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 2ad6d3a47607..80e108e4a6ff 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -32,6 +32,8 @@ from tqdm import auto as tqdm_lib +from .distributed_utils import is_torch_dist_rank_zero + _lock = threading.Lock() _default_handler: Optional[logging.Handler] = None @@ -47,6 +49,23 @@ _default_log_level = logging.WARNING _tqdm_active = True +_rank_zero_filter = None + + +class _RankZeroFilter(logging.Filter): + def filter(self, record): + # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting. 
+ return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG + + +def _ensure_rank_zero_filter(logger: logging.Logger) -> None: + global _rank_zero_filter + + if _rank_zero_filter is None: + _rank_zero_filter = _RankZeroFilter() + + if not any(isinstance(f, _RankZeroFilter) for f in logger.filters): + logger.addFilter(_rank_zero_filter) def _get_default_logging_level() -> int: @@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None: library_root_logger.addHandler(_default_handler) library_root_logger.setLevel(_get_default_logging_level()) library_root_logger.propagate = False + _ensure_rank_zero_filter(library_root_logger) def _reset_library_root_logger() -> None: @@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: name = _get_library_name() _configure_library_root_logger() - return logging.getLogger(name) + logger = logging.getLogger(name) + _ensure_rank_zero_filter(logger) + return logger def get_verbosity() -> int: diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 50bfce8b15eb..c9bf5fec289b 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -56,6 +56,36 @@ class StateDictType(enum.Enum): ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector", } +CONTROL_LORA_TO_DIFFUSERS = { + ".to_q.down": ".to_q.lora_A.weight", + ".to_q.up": ".to_q.lora_B.weight", + ".to_k.down": ".to_k.lora_A.weight", + ".to_k.up": ".to_k.lora_B.weight", + ".to_v.down": ".to_v.lora_A.weight", + ".to_v.up": ".to_v.lora_B.weight", + ".to_out.0.down": ".to_out.0.lora_A.weight", + ".to_out.0.up": ".to_out.0.lora_B.weight", + ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight", + ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight", + ".ff.net.2.down": ".ff.net.2.lora_A.weight", + ".ff.net.2.up": ".ff.net.2.lora_B.weight", + ".proj_in.down": ".proj_in.lora_A.weight", + ".proj_in.up": ".proj_in.lora_B.weight", + ".proj_out.down": ".proj_out.lora_A.weight", + ".proj_out.up": ".proj_out.lora_B.weight", + ".conv.down": ".conv.lora_A.weight", + ".conv.up": ".conv.lora_B.weight", + **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)}, + **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)}, + "conv_in.down": "conv_in.lora_A.weight", + "conv_in.up": "conv_in.lora_B.weight", + ".conv_shortcut.down": ".conv_shortcut.lora_A.weight", + ".conv_shortcut.up": ".conv_shortcut.lora_B.weight", + **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)}, + **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)}, + "time_emb_proj.down": "time_emb_proj.lora_A.weight", + "time_emb_proj.up": "time_emb_proj.lora_B.weight", +} DIFFUSERS_TO_PEFT = { ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", @@ -259,6 +289,155 @@ def convert_unet_state_dict_to_peft(state_dict): return convert_state_dict(state_dict, mapping) +def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): + def _convert_controlnet_to_diffusers(state_dict): + is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict + logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})") + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + layers_per_block = 2 + + # op blocks + op_blocks = [key 
for key in state_dict if "0.op" in key] + + converted_state_dict = {} + # Conv in layers + for key in input_blocks[0]: + diffusers_key = key.replace("input_blocks.0.0", "conv_in") + converted_state_dict[diffusers_key] = state_dict.get(key) + + # controlnet time embedding blocks + time_embedding_blocks = [key for key in state_dict if "time_embed" in key] + for key in time_embedding_blocks: + diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace( + "time_embed.2", "time_embedding.linear_2" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + # controlnet label embedding blocks + label_embedding_blocks = [key for key in state_dict if "label_emb" in key] + for key in label_embedding_blocks: + diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace( + "label_emb.0.2", "add_embedding.linear_2" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (layers_per_block + 1) + layer_in_block_id = (i - 1) % (layers_per_block + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + for key in resnets: + diffusers_key = ( + key.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + diffusers_key = diffusers_key.replace( + f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + if f"input_blocks.{i}.0.op.bias" in state_dict: + for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]: + diffusers_key = key.replace( + f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + for key in attentions: + diffusers_key = key.replace( + f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}" + ) + converted_state_dict[diffusers_key] = state_dict.get(key) + + # controlnet down blocks + for i in range(num_input_blocks): + converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight") + converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Mid blocks + for key in middle_blocks.keys(): + diffusers_key = max(key - 1, 0) + if key % 2 == 0: + for k in middle_blocks[key]: + diffusers_key_hf = ( + k.replace("in_layers.0", "norm1") + .replace("in_layers.2", "conv1") + .replace("out_layers.0", "norm2") + .replace("out_layers.3", "conv2") + .replace("emb_layers.1", "time_emb_proj") + .replace("skip_connection", "conv_shortcut") + ) + diffusers_key_hf = diffusers_key_hf.replace( + f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}" + ) + converted_state_dict[diffusers_key_hf] = state_dict.get(k) + else: + for k in middle_blocks[key]: + diffusers_key_hf = k.replace(f"middle_block.{key}", 
f"mid_block.attentions.{diffusers_key}") + converted_state_dict[diffusers_key_hf] = state_dict.get(k) + + # mid block + converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight") + converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in state_dict + if "input_hint_block" in layer + and ("input_hint_block.0" not in layer) + and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get( + f"input_hint_block.{cond_block_id}.weight" + ) + converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get( + f"input_hint_block.{cond_block_id}.bias" + ) + + for key in [key for key in state_dict if "input_hint_block.0" in key]: + diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in") + converted_state_dict[diffusers_key] = state_dict.get(key) + + for key in [key for key in state_dict if "input_hint_block.14" in key]: + diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out") + converted_state_dict[diffusers_key] = state_dict.get(key) + + return converted_state_dict + + state_dict = _convert_controlnet_to_diffusers(state_dict) + mapping = CONTROL_LORA_TO_DIFFUSERS + return convert_state_dict(state_dict, mapping) + + def convert_all_state_dict_to_peft(state_dict): r""" Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid diff --git a/tests/conftest.py b/tests/conftest.py index fd76d1c84ee7..1d7e83467f03 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,24 @@ def pytest_configure(config): config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources") + config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality") + config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality") + config.addinivalue_line("markers", "training: marks tests for training functionality") + config.addinivalue_line("markers", "attention: marks tests for attention processor functionality") + config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality") + config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality") + config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality") + config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality") + config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading") + config.addinivalue_line("markers", "quantization: marks tests for quantization functionality") + config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality") + config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality") + config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality") + config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality") + config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt 
quantization functionality") + config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality") + config.addinivalue_line("markers", "slow: mark test as slow") + config.addinivalue_line("markers", "nightly: mark test as nightly") def pytest_addoption(parser): diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 96cbecfbf530..236094109d07 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -19,6 +19,7 @@ import torch from parameterized import parameterized +from diffusers import AutoencoderKL from diffusers.hooks import HookRegistry, ModelHook from diffusers.models import ModelMixin from diffusers.pipelines.pipeline_utils import DiffusionPipeline @@ -149,6 +150,74 @@ def post_forward(self, module, output): return output +# Model with only standalone computational layers at top level +class DummyModelWithStandaloneLayers(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.layer1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.layer2 = torch.nn.Linear(hidden_features, hidden_features) + self.layer3 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.layer1(x) + x = self.activation(x) + x = self.layer2(x) + x = self.layer3(x) + return x + + +# Model with deeply nested structure +class DummyModelWithDeeplyNestedBlocks(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.input_layer = torch.nn.Linear(in_features, hidden_features) + self.container = ContainerWithNestedModuleList(hidden_features) + self.output_layer = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.input_layer(x) + x = self.container(x) + x = self.output_layer(x) + return x + + +class ContainerWithNestedModuleList(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + # Top-level computational layer + self.proj_in = torch.nn.Linear(features, features) + + # Nested container with ModuleList + self.nested_container = NestedContainer(features) + + # Another top-level computational layer + self.proj_out = torch.nn.Linear(features, features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.nested_container(x) + x = self.proj_out(x) + return x + + +class NestedContainer(torch.nn.Module): + def __init__(self, features: int) -> None: + super().__init__() + + self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)]) + self.norm = torch.nn.LayerNorm(features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + x = self.norm(x) + return x + + @require_torch_accelerator class GroupOffloadTests(unittest.TestCase): in_features = 64 @@ -340,7 +409,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): out = model(x) self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.") - num_repeats = 4 + num_repeats = 2 for i in range(num_repeats): out_ref = model_ref(x) out = model(x) @@ -362,3 +431,138 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): self.assertLess( cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: 
{cumulated_absmax:.5f}" ) + + def test_vae_like_model_without_streams(self): + """Test VAE-like model with block-level offloading but without streams.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) + + model_ref = AutoencoderKL(**config) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False) + + x = torch.randn(2, 3, 32, 32).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x).sample + out = model(x).sample + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams." + ) + + def test_model_with_only_standalone_layers(self): + """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(2): + out_ref = model_ref(x) + out = model(x) + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match at iteration {i} for model with standalone layers.", + ) + + @parameterized.expand([("block_level",), ("leaf_level",)]) + def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str): + """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) + + model_ref = AutoencoderKL(**config) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 3, 32, 32).to(torch_device) + + with torch.no_grad(): + out_ref = model_ref(x).sample + out = model(x).sample + + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match for standalone Conv layers with {offload_type}.", + ) + + def test_multiple_invocations_with_vae_like_model(self): + """Test that multiple forward passes work correctly with VAE-like model.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + config = self.get_autoencoder_kl_config() + model = AutoencoderKL(**config) + + model_ref = AutoencoderKL(**config) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 3, 32, 32).to(torch_device) + + with torch.no_grad(): + for i in range(2): + out_ref = model_ref(x).sample + out = model(x).sample + self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") + + def test_nested_container_parameters_offloading(self): + """Test that parameters from non-computational layers 
in nested containers are handled correctly.""" + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + + model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64) + model_ref.load_state_dict(model.state_dict(), strict=True) + model_ref.to(torch_device) + + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) + + x = torch.randn(2, 64).to(torch_device) + + with torch.no_grad(): + for i in range(2): + out_ref = model_ref(x) + out = model(x) + self.assertTrue( + torch.allclose(out_ref, out, atol=1e-5), + f"Outputs do not match at iteration {i} for nested parameters.", + ) + + def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None): + block_out_channels = block_out_channels or [2, 4] + norm_num_groups = norm_num_groups or 2 + init_dict = { + "block_out_channels": block_out_channels, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), + "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels), + "latent_channels": 4, + "norm_num_groups": norm_num_groups, + "layers_per_block": 1, + } + return init_dict diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py new file mode 100644 index 000000000000..a7e1b52d3b69 --- /dev/null +++ b/tests/hooks/test_mag_cache.py @@ -0,0 +1,244 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
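+
+# The dummy transformers defined below use blocks that simply double their input,
+# so expected outputs (and the residual MagCache reuses when a step is skipped)
+# can be computed by hand in the assertions that follow.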
+ +import unittest + +import numpy as np +import torch + +from diffusers import MagCacheConfig, apply_mag_cache +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.models import ModelMixin +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) + + +class DummyBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Output is double input + # This ensures Residual = 2*Input - Input = Input + return hidden_states * 2.0 + + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + + +class TupleOutputBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Returns a tuple + return hidden_states * 2.0, encoder_hidden_states + + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + # Emulate Flux-like behavior + output = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = output[0] + encoder_hidden_states = output[1] + return hidden_states, encoder_hidden_states + + +class MagCacheTests(unittest.TestCase): + def setUp(self): + # Register standard dummy block + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), + ) + # Register tuple block (Flux style) + TransformerBlockRegistry.register( + TupleOutputBlock, + TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), + ) + + def _set_context(self, model, context_name): + """Helper to set context on all hooks in the model.""" + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + def _get_calibration_data(self, model): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("mag_cache_block_hook") + if hook: + return hook.state_manager.get_state().calibration_ratios + return [] + + def test_mag_cache_validation(self): + """Test that missing mag_ratios raises ValueError.""" + with self.assertRaises(ValueError): + MagCacheConfig(num_inference_steps=10, calibrate=False) + + def test_mag_cache_skipping_logic(self): + """ + Tests that MagCache correctly calculates residuals and skips blocks when conditions are met. + """ + model = DummyTransformer() + + # Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, # Enable immediate skipping + max_skip_steps=5, + mag_ratios=ratios, + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each) + # HeadInput=10. Output=40. Residual=30. 
+ input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed") + + # Step 1: Input 11.0. + # If Skipped: Output = Input(11) + Residual(30) = 41.0 + # If Computed: Output = 11 * 4 = 44.0 + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}" + ) + + def test_mag_cache_retention(self): + """Test that retention_ratio prevents skipping even if error is low.""" + model = DummyTransformer() + # Ratios that imply 0 error, so it *would* skip if retention allowed it + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=1.0, # Force retention for ALL steps + mag_ratios=ratios, + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + model(torch.tensor([[[10.0]]])) + + # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[44.0]]])), + f"Expected Compute (44.0) due to retention, got {output_t1.item()}", + ) + + def test_mag_cache_tuple_outputs(self): + """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" + model = TupleTransformer() + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) + # Residual = 10.0 + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]]))) + + # Step 1: Skip. Input 11.0. + # Skipped Output = 11 + 10 = 21.0 + input_t1 = torch.tensor([[[11.0]]]) + out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) + + self.assertTrue( + torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}" + ) + + def test_mag_cache_reset(self): + """Test that state resets correctly after num_inference_steps.""" + model = DummyTransformer() + config = MagCacheConfig( + threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0]) + ) + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + input_t = torch.ones(1, 1, 1) + + model(input_t) # Step 0 + model(input_t) # Step 1 (Skipped) + + # Step 2 (Reset -> Step 0) -> Should Compute + # Input 2.0 -> Output 8.0 + input_t2 = torch.tensor([[[2.0]]]) + output_t2 = model(input_t2) + + self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly") + + def test_mag_cache_calibration(self): + """Test that calibration mode records ratios.""" + model = DummyTransformer() + config = MagCacheConfig(num_inference_steps=2, calibrate=True) + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + # HeadInput = 10. Output = 40. Residual = 30. + # Ratio 0 is placeholder 1.0 + model(torch.tensor([[[10.0]]])) + + # Check intermediate state + ratios = self._get_calibration_data(model) + self.assertEqual(len(ratios), 1) + self.assertEqual(ratios[0], 1.0) + + # Step 1 + # HeadInput = 10. Output = 40. Residual = 30. + # PrevResidual = 30. 
CurrResidual = 30. + # Ratio = 30/30 = 1.0 + model(torch.tensor([[[10.0]]])) + + # Verify it computes fully (no skip) + # If it skipped, output would be 41.0. It should be 40.0 + # Actually in test setup, input is same (10.0) so output 40.0. + # Let's ensure list is empty after reset (end of step 1) + ratios_after = self._get_calibration_data(model) + self.assertEqual(ratios_after, []) diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 91f63c4b56c4..78ef4ce151be 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -114,23 +116,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in AuraFlow.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py index fa57b4c9c2f9..7bd54b77ca35 100644 --- a/tests/lora/test_lora_layers_cogvideox.py +++ b/tests/lora/test_lora_layers_cogvideox.py @@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -147,26 +149,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogVideoX.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py index 30eb8fbb6367..e8ee6e7a7db6 100644 --- a/tests/lora/test_lora_layers_cogview4.py +++ b/tests/lora/test_lora_layers_cogview4.py @@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder", ) + 
supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -162,23 +164,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in CogView4.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in CogView4.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py index 4ae189aceb66..d970b7d7847f 100644 --- a/tests/lora/test_lora_layers_flux2.py +++ b/tests/lora/test_lora_layers_flux2.py @@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers" denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -146,23 +148,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in Flux2.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Flux2.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py index cfd5d3146a91..e59bc5662fe1 100644 --- a/tests/lora/test_lora_layers_hunyuanvideo.py +++ b/tests/lora/test_lora_layers_hunyuanvideo.py @@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "text_encoder_2", ) + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -172,26 +174,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported 
in HunyuanVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @nightly @require_torch_accelerator diff --git a/tests/lora/test_lora_layers_ltx2.py b/tests/lora/test_lora_layers_ltx2.py new file mode 100644 index 000000000000..0a4b14454f5b --- /dev/null +++ b/tests/lora/test_lora_layers_ltx2.py @@ -0,0 +1,271 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder +from diffusers.utils.import_utils import is_peft_available + +from ..testing_utils import floats_tensor, require_peft_backend + + +if is_peft_available(): + from peft import LoraConfig + + +sys.path.append(".") + +from .utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = LTX2Pipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "num_layers": 1, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 32, + "rope_double_precision": False, + "rope_type": "split", + } + transformer_cls = LTX2VideoTransformer3DModel + + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "block_out_channels": (8,), + "decoder_block_out_channels": (8,), + "layers_per_block": (1,), + "decoder_layers_per_block": (1, 1), + "spatio_temporal_scaling": (True,), + "decoder_spatio_temporal_scaling": (True,), + "decoder_inject_noise": (False, False), + "downsample_type": ("spatial",), + "upsample_residual": (False,), + "upsample_factor": (1,), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + vae_cls = AutoencoderKLLTX2Video + + audio_vae_kwargs = { + "base_channels": 4, + "output_channels": 2, + "ch_mult": (1,), + "num_res_blocks": 1, + "attn_resolutions": None, + "in_channels": 2, + "resolution": 32, + "latent_channels": 2, + "norm_type": "pixel", + "causality_axis": "height", + "dropout": 0.0, + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "is_causal": True, + "mel_bins": 8, + } + audio_vae_cls = AutoencoderKLLTX2Audio + + vocoder_kwargs = { + "in_channels": 16, # output_channels * mel_bins = 2 * 8 + "hidden_channels": 32, + "out_channels": 2, + "upsample_kernel_sizes": [4, 4], + 
"upsample_factors": [2, 2], + "resnet_kernel_sizes": [3], + "resnet_dilations": [[1, 3, 5]], + "leaky_relu_negative_slope": 0.1, + "output_sampling_rate": 16000, + } + vocoder_cls = LTX2Vocoder + + connectors_kwargs = { + "caption_channels": 32, # Will be set dynamically from text_encoder + "text_proj_in_factor": 2, # Will be set dynamically from text_encoder + "video_connector_num_attention_heads": 4, + "video_connector_attention_head_dim": 8, + "video_connector_num_layers": 1, + "video_connector_num_learnable_registers": None, + "audio_connector_num_attention_heads": 4, + "audio_connector_attention_head_dim": 8, + "audio_connector_num_layers": 1, + "audio_connector_num_learnable_registers": None, + "connector_rope_base_seq_len": 32, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_temporal_positioning": False, + "rope_type": "split", + } + connectors_cls = LTX2TextConnectors + + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-gemma3" + text_encoder_cls, text_encoder_id = ( + Gemma3ForConditionalGeneration, + "hf-internal-testing/tiny-gemma3", + ) + + denoiser_target_modules = ["to_q", "to_k", "to_out.0"] + + supports_text_encoder_loras = False + + @property + def output_shape(self): + return (1, 5, 32, 32, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 16 + num_channels = 4 + num_frames = 5 + num_latent_frames = 2 + latent_height = 8 + latent_width = 8 + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width)) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "a robot dancing", + "num_frames": num_frames, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "frame_rate": 25.0, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): + # Override to instantiate LTX2-specific components (connectors, audio_vae, vocoder) + torch.manual_seed(0) + text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + + # Update caption_channels and text_proj_in_factor based on text_encoder config + transformer_kwargs = self.transformer_kwargs.copy() + transformer_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size + + connectors_kwargs = self.connectors_kwargs.copy() + connectors_kwargs["caption_channels"] = text_encoder.config.text_config.hidden_size + connectors_kwargs["text_proj_in_factor"] = text_encoder.config.text_config.num_hidden_layers + 1 + + torch.manual_seed(0) + transformer = self.transformer_cls(**transformer_kwargs) + + torch.manual_seed(0) + vae = self.vae_cls(**self.vae_kwargs) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = self.audio_vae_cls(**self.audio_vae_kwargs) + + torch.manual_seed(0) + vocoder = self.vocoder_cls(**self.vocoder_kwargs) + + torch.manual_seed(0) + connectors = self.connectors_cls(**connectors_kwargs) + + if scheduler_cls is None: + scheduler_cls = self.scheduler_cls + scheduler = scheduler_cls(**self.scheduler_kwargs) + + rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha + + 
text_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=self.text_encoder_target_modules, + init_lora_weights=False, + use_dora=use_dora, + ) + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + + pipeline_components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return pipeline_components, text_lora_config, denoiser_lora_config + + def test_simple_inference_with_text_lora_denoiser_fused_multi(self): + super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3) + + def test_simple_inference_with_text_denoiser_lora_unfused(self): + super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3) + + @unittest.skip("Not supported in LTX2.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in LTX2.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in LTX2.") + def test_modify_padding_mode(self): + pass diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 6ab51a5e513f..095e5b577cf0 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -125,23 +127,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in LTXVideo.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in LTXVideo.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py index 0417b05b33a1..da032229a785 100644 --- a/tests/lora/test_lora_layers_lumina2.py +++ b/tests/lora/test_lora_layers_lumina2.py @@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 4, 4, 3) @@ -113,26 +115,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported 
in Lumina2.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Lumina2.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py index 7be81273db77..ee8254112924 100644 --- a/tests/lora/test_lora_layers_mochi.py +++ b/tests/lora/test_lora_layers_mochi.py @@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 7, 16, 16, 3) @@ -117,26 +119,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Mochi.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skip("Not supported in CogVideoX.") def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): pass diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py index 51de2f8e20e1..73fd026a670c 100644 --- a/tests/lora/test_lora_layers_qwenimage.py +++ b/tests/lora/test_lora_layers_qwenimage.py @@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ) denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 8, 8, 3) @@ -107,23 +109,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in Qwen Image.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Qwen Image.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py index a860b7b44f2c..97bf5cbba920 100644 --- a/tests/lora/test_lora_layers_sana.py +++ b/tests/lora/test_lora_layers_sana.py @@ -75,6 +75,8 @@ class 
SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma" text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers" + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -117,26 +119,6 @@ def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in SANA.") - def test_simple_inference_with_text_lora_save_load(self): - pass - @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment") def test_layerwise_casting_inference_denoiser(self): return super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py index 5734509b410f..5ae16ab4b9da 100644 --- a/tests/lora/test_lora_layers_wan.py +++ b/tests/lora/test_lora_layers_wan.py @@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 32, 32, 3) @@ -121,23 +123,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in Wan.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index ab1f57bfc9da..c8acaea9bef0 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_target_modules = ["q", "k", "v", "o"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 9, 16, 16, 3) @@ -139,26 +141,6 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se def test_modify_padding_mode(self): pass - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan 
VACE.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in Wan VACE.") - def test_simple_inference_with_text_lora_save_load(self): - pass - def test_layerwise_casting_inference_denoiser(self): super().test_layerwise_casting_inference_denoiser() diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py index 35d1389d9612..8432ea56a6fa 100644 --- a/tests/lora/test_lora_layers_z_image.py +++ b/tests/lora/test_lora_layers_z_image.py @@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + supports_text_encoder_loras = False + @property def output_shape(self): return (1, 32, 32, 3) @@ -263,23 +265,3 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se @unittest.skip("Not supported in ZImage.") def test_modify_padding_mode(self): pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_partial_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_and_scale(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_fused(self): - pass - - @unittest.skip("Text encoder LoRA is not supported in ZImage.") - def test_simple_inference_with_text_lora_save_load(self): - pass diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5fae6cac0a7f..efa49b9f4838 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests: tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, "" tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, "" tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, "" + supports_text_encoder_loras = True unet_kwargs = None transformer_cls = None @@ -333,6 +334,9 @@ def test_simple_inference_with_text_lora(self): Tests a simple inference with lora attached on the text encoder and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -457,6 +461,9 @@ def test_simple_inference_with_text_lora_and_scale(self): Tests a simple inference with lora attached on the text encoder + scale argument and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class) components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -494,6 +501,9 @@ def test_simple_inference_with_text_lora_fused(self): Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping 
test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -555,6 +565,9 @@ def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) @@ -593,6 +606,9 @@ def test_simple_inference_with_partial_text_lora(self): with different ranks and some adapters removed and makes sure it works as expected """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, _, _ = self.get_dummy_components() # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( @@ -651,6 +667,9 @@ def test_simple_inference_save_pretrained_with_text_lora(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ + if not self.supports_text_encoder_loras: + pytest.skip("Skipping test as text encoder LoRAs are not currently supported.") + components, text_lora_config, _ = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device) diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py new file mode 100644 index 000000000000..ce93dfb42afe --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_kl_ltx2_audio.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
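+
+# The configuration below is a deliberately tiny audio VAE (stereo spectrogram
+# input, 16 mel bins) so the shared ModelTesterMixin / AutoencoderTesterMixin
+# checks stay fast.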
+ +import unittest + +from diffusers import AutoencoderKLLTX2Audio + +from ...testing_utils import ( + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Audio + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 2, # stereo, + "output_channels": 2, + "latent_channels": 4, + "base_channels": 16, + "ch_mult": (1, 2, 4), + "resolution": 16, + "attn_resolutions": None, + "num_res_blocks": 2, + "norm_type": "pixel", + "causality_axis": "height", + "mid_block_add_attention": False, + "sample_rate": 16000, + "mel_hop_length": 160, + "mel_bins": 16, + "is_causal": True, + "double_z": True, + } + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 2 + num_frames = 8 + num_mel_bins = 16 + + spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device) + + input_dict = {"sample": spectrogram} + return input_dict + + @property + def input_shape(self): + return (2, 5, 16) + + @property + def output_shape(self): + return (2, 5, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + # Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE + def test_output(self): + super().test_output(expected_output_shape=(2, 2, 5, 16)) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py new file mode 100644 index 000000000000..146241361a82 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -0,0 +1,103 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
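+
+# Small LTX 2.0 video VAE configuration exercised through the shared
+# ModelTesterMixin / AutoencoderTesterMixin checks; gradient checkpointing
+# coverage is asserted against the encoder/decoder block classes listed in
+# test_gradient_checkpointing_is_applied.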
+ +import unittest + +from diffusers import AutoencoderKLLTX2Video + +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) +from ..test_modeling_common import ModelTesterMixin +from .testing_utils import AutoencoderTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTX2Video + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, True), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (False, False, False, False), + "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + # Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros` + "decoder_spatial_padding_mode": "zeros", + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + input_dict = {"sample": image} + return input_dict + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTX2VideoEncoder3d", + "LTX2VideoDecoder3d", + "LTX2VideoDownBlock3D", + "LTX2VideoMidBlock3d", + "LTX2VideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index ad5a6ba48010..b9dfe932335c 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1791,7 +1791,6 @@ def run_forward(model): return model(**inputs_dict)[0] model = self.model_class(**init_dict) - model.to(torch_device) output_without_group_offloading = run_forward(model) output_without_group_offloading = normalize_output(output_without_group_offloading) @@ -1916,6 +1915,9 @@ def _run_forward(model, inputs_dict): offload_to_disk_path=tmpdir, offload_type=offload_type, num_blocks_per_group=num_blocks_per_group, + block_modules=model._group_offload_block_modules + if hasattr(model, "_group_offload_block_modules") + else None, ) if not is_correct: if extra_files: diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py new file mode 100644 index 000000000000..ea076b3ec774 --- /dev/null +++ b/tests/models/testing_utils/__init__.py @@ -0,0 +1,79 @@ +from 
.attention import AttentionTesterMixin +from .cache import ( + CacheTesterMixin, + FasterCacheConfigMixin, + FasterCacheTesterMixin, + FirstBlockCacheConfigMixin, + FirstBlockCacheTesterMixin, + PyramidAttentionBroadcastConfigMixin, + PyramidAttentionBroadcastTesterMixin, +) +from .common import BaseModelTesterConfig, ModelTesterMixin +from .compile import TorchCompileTesterMixin +from .ip_adapter import IPAdapterTesterMixin +from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin +from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin +from .parallelism import ContextParallelTesterMixin +from .quantization import ( + BitsAndBytesCompileTesterMixin, + BitsAndBytesConfigMixin, + BitsAndBytesTesterMixin, + GGUFCompileTesterMixin, + GGUFConfigMixin, + GGUFTesterMixin, + ModelOptCompileTesterMixin, + ModelOptConfigMixin, + ModelOptTesterMixin, + QuantizationCompileTesterMixin, + QuantizationTesterMixin, + QuantoCompileTesterMixin, + QuantoConfigMixin, + QuantoTesterMixin, + TorchAoCompileTesterMixin, + TorchAoConfigMixin, + TorchAoTesterMixin, +) +from .single_file import SingleFileTesterMixin +from .training import TrainingTesterMixin + + +__all__ = [ + "AttentionTesterMixin", + "BaseModelTesterConfig", + "BitsAndBytesCompileTesterMixin", + "BitsAndBytesConfigMixin", + "BitsAndBytesTesterMixin", + "CacheTesterMixin", + "ContextParallelTesterMixin", + "CPUOffloadTesterMixin", + "FasterCacheConfigMixin", + "FasterCacheTesterMixin", + "FirstBlockCacheConfigMixin", + "FirstBlockCacheTesterMixin", + "GGUFCompileTesterMixin", + "GGUFConfigMixin", + "GGUFTesterMixin", + "GroupOffloadTesterMixin", + "IPAdapterTesterMixin", + "LayerwiseCastingTesterMixin", + "LoraHotSwappingForModelTesterMixin", + "LoraTesterMixin", + "MemoryTesterMixin", + "ModelOptCompileTesterMixin", + "ModelOptConfigMixin", + "ModelOptTesterMixin", + "ModelTesterMixin", + "PyramidAttentionBroadcastConfigMixin", + "PyramidAttentionBroadcastTesterMixin", + "QuantizationCompileTesterMixin", + "QuantizationTesterMixin", + "QuantoCompileTesterMixin", + "QuantoConfigMixin", + "QuantoTesterMixin", + "SingleFileTesterMixin", + "TorchAoCompileTesterMixin", + "TorchAoConfigMixin", + "TorchAoTesterMixin", + "TorchCompileTesterMixin", + "TrainingTesterMixin", +] diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py new file mode 100644 index 000000000000..134b3fa33bfe --- /dev/null +++ b/tests/models/testing_utils/attention.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
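+
+# Usage sketch (illustrative): a model's test class opts in by combining this
+# mixin with a config mixin that defines `model_class`, `get_init_dict()` and
+# `get_dummy_inputs()`, e.g.
+#   class MyModelAttentionTests(MyModelTesterConfig, AttentionTesterMixin): ...
+# where `MyModelTesterConfig` is a hypothetical config class for the model.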
+ +import gc + +import pytest +import torch + +from diffusers.models.attention import AttentionModuleMixin +from diffusers.models.attention_processor import ( + AttnProcessor, +) + +from ...testing_utils import ( + assert_tensors_close, + backend_empty_cache, + is_attention, + torch_device, +) + + +@is_attention +class AttentionTesterMixin: + """ + Mixin class for testing attention processor and module functionality on models. + + Tests functionality from AttentionModuleMixin including: + - Attention processor management (set/get) + - QKV projection fusion/unfusion + - Attention backends (XFormers, NPU, etc.) + + Expected from config mixin: + - model_class: The model class to test + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: attention + Use `pytest -m "not attention"` to skip these tests + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + @torch.no_grad() + def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + if not hasattr(model, "fuse_qkv_projections"): + pytest.skip("Model does not support QKV projection fusion.") + + output_before_fusion = model(**inputs_dict, return_dict=False)[0] + + model.fuse_qkv_projections() + + has_fused_projections = False + for module in model.modules(): + if isinstance(module, AttentionModuleMixin): + if hasattr(module, "to_qkv") or hasattr(module, "to_kv"): + has_fused_projections = True + assert module.fused_projections, "fused_projections flag should be True" + break + + if has_fused_projections: + output_after_fusion = model(**inputs_dict, return_dict=False)[0] + + assert_tensors_close( + output_before_fusion, + output_after_fusion, + atol=atol, + rtol=rtol, + msg="Output should not change after fusing projections", + ) + + model.unfuse_qkv_projections() + + for module in model.modules(): + if isinstance(module, AttentionModuleMixin): + assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing" + assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing" + assert not module.fused_projections, "fused_projections flag should be False" + + output_after_unfusion = model(**inputs_dict, return_dict=False)[0] + + assert_tensors_close( + output_before_fusion, + output_after_unfusion, + atol=atol, + rtol=rtol, + msg="Output should match original after unfusing projections", + ) + + def test_get_set_processor(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + # Check if model has attention processors + if not hasattr(model, "attn_processors"): + pytest.skip("Model does not have attention processors.") + + # Test getting processors + processors = model.attn_processors + assert isinstance(processors, dict), "attn_processors should return a dict" + assert len(processors) > 0, "Model should have at least one attention processor" + + # Test that all processors can be retrieved via get_processor + for module in model.modules(): + if isinstance(module, AttentionModuleMixin): + processor = module.get_processor() + assert processor is not None, "get_processor should return a processor" + + # Test setting a new processor + 
new_processor = AttnProcessor() + module.set_processor(new_processor) + retrieved_processor = module.get_processor() + assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set" + + def test_attention_processor_dict(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + if not hasattr(model, "set_attn_processor"): + pytest.skip("Model does not support setting attention processors.") + + # Get current processors + current_processors = model.attn_processors + + # Create a dict of new processors + new_processors = {key: AttnProcessor() for key in current_processors.keys()} + + # Set processors using dict + model.set_attn_processor(new_processors) + + # Verify all processors were set + updated_processors = model.attn_processors + for key in current_processors.keys(): + assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor" + + def test_attention_processor_count_mismatch_raises_error(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + model.to(torch_device) + + if not hasattr(model, "set_attn_processor"): + pytest.skip("Model does not support setting attention processors.") + + # Get current processors + current_processors = model.attn_processors + + # Create a dict with wrong number of processors + wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()} + + # Verify error is raised + with pytest.raises(ValueError) as exc_info: + model.set_attn_processor(wrong_processors) + + assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py new file mode 100644 index 000000000000..f1c2ecba88a7 --- /dev/null +++ b/tests/models/testing_utils/cache.py @@ -0,0 +1,556 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import pytest +import torch + +from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig +from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK +from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK +from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK +from diffusers.models.cache_utils import CacheMixin + +from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device + + +def require_cache_mixin(func): + """Decorator to skip tests if model doesn't use CacheMixin.""" + + def wrapper(self, *args, **kwargs): + if not issubclass(self.model_class, CacheMixin): + pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.") + return func(self, *args, **kwargs) + + return wrapper + + +class CacheTesterMixin: + """ + Base mixin class providing common test implementations for cache testing. 
+ + Cache-specific mixins should: + 1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin) + 2. Inherit from this mixin + 3. Define the cache config to use for tests + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods in test classes: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional overrides: + - cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states") + """ + + @property + def cache_input_key(self): + return "hidden_states" + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def _get_cache_config(self): + """ + Get the cache config for testing. + Should be implemented by subclasses. + """ + raise NotImplementedError("Subclass must implement _get_cache_config") + + def _get_hook_names(self): + """ + Get the hook names to check for this cache type. + Should be implemented by subclasses. + Returns a list of hook name strings. + """ + raise NotImplementedError("Subclass must implement _get_hook_names") + + def _test_cache_enable_disable_state(self): + """Test that cache enable/disable updates the is_cache_enabled state correctly.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + # Initially cache should not be enabled + assert not model.is_cache_enabled, "Cache should not be enabled initially." + + config = self._get_cache_config() + + # Enable cache + model.enable_cache(config) + assert model.is_cache_enabled, "Cache should be enabled after enable_cache()." + + # Disable cache + model.disable_cache() + assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()." + + def _test_cache_double_enable_raises_error(self): + """Test that enabling cache twice raises an error.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + config = self._get_cache_config() + + model.enable_cache(config) + + # Trying to enable again should raise ValueError + with pytest.raises(ValueError, match="Caching has already been enabled"): + model.enable_cache(config) + + # Cleanup + model.disable_cache() + + def _test_cache_hooks_registered(self): + """Test that cache hooks are properly registered and removed.""" + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + config = self._get_cache_config() + hook_names = self._get_hook_names() + + model.enable_cache(config) + + # Check that at least one hook was registered + hook_count = 0 + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + for hook_name in hook_names: + hook = module._diffusers_hook.get_hook(hook_name) + if hook is not None: + hook_count += 1 + + assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}" + + # Disable and verify hooks are removed + model.disable_cache() + + hook_count_after = 0 + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + for hook_name in hook_names: + hook = module._diffusers_hook.get_hook(hook_name) + if hook is not None: + hook_count_after += 1 + + assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()." 
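+    # Illustrative note: the concrete mixins below (PyramidAttentionBroadcast / FirstBlockCache /
+    # FasterCache) implement _get_cache_config() and _get_hook_names(), and a model test class is
+    # expected to compose one of them with its config mixin, e.g. (hypothetical names):
+    #
+    #     class TestMyTransformerCache(MyTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
+    #         pass
+    #
+    # get_init_dict()/get_dummy_inputs() then come from the config mixin, and the _test_cache_*
+    # helpers here are exercised through the public test_* wrappers of the concrete mixin.
+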
+ + @torch.no_grad() + def _test_cache_inference(self): + """Test that model can run inference with cache enabled.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + + model.enable_cache(config) + + # First pass populates the cache + _ = model(**inputs_dict, return_dict=False)[0] + + # Create modified inputs for second pass (vary input tensor to simulate denoising) + inputs_dict_step2 = inputs_dict.copy() + if self.cache_input_key in inputs_dict_step2: + inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like( + inputs_dict_step2[self.cache_input_key] + ) + + # Second pass uses cached attention with different inputs (produces approximated output) + output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] + + assert output_with_cache is not None, "Model output should not be None with cache enabled." + assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled." + + # Run same inputs without cache to compare + model.disable_cache() + output_without_cache = model(**inputs_dict_step2, return_dict=False)[0] + + # Cached output should be different from non-cached output (due to approximation) + assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), ( + "Cached output should be different from non-cached output due to cache approximation." + ) + + @torch.no_grad() + def _test_cache_context_manager(self, atol=1e-5, rtol=0): + """Test the cache_context context manager properly isolates cache state.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + # Run inference in first context + with model.cache_context("context_1"): + output_ctx1 = model(**inputs_dict, return_dict=False)[0] + + # Run same inference in second context (cache should be reset) + with model.cache_context("context_2"): + output_ctx2 = model(**inputs_dict, return_dict=False)[0] + + # Both contexts should produce the same output (first pass in each) + assert_tensors_close( + output_ctx1, + output_ctx2, + atol=atol, + rtol=rtol, + msg="First pass in different cache contexts should produce the same output.", + ) + + model.disable_cache() + + @torch.no_grad() + def _test_reset_stateful_cache(self): + """Test that _reset_stateful_cache resets the cache state.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + + model.enable_cache(config) + + _ = model(**inputs_dict, return_dict=False)[0] + + model._reset_stateful_cache() + + model.disable_cache() + + +@is_cache +class PyramidAttentionBroadcastConfigMixin: + """ + Base mixin providing PyramidAttentionBroadcast cache config. 
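+
+    Illustrative override (a sketch; the class name is hypothetical): `PAB_CONFIG` is forwarded as
+    `PyramidAttentionBroadcastConfig` kwargs, and `_current_timestep` (which must stay inside the
+    default (100, 800) skip window) backs the `current_timestep_callback`:
+
+        class MyPABConfigMixin(PyramidAttentionBroadcastConfigMixin):
+            PAB_CONFIG = {"spatial_attention_block_skip_range": 4}
+            _current_timestep = 300
+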
+ + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + """ + + # Default PAB config - can be overridden by subclasses + PAB_CONFIG = { + "spatial_attention_block_skip_range": 2, + } + + # Store timestep for callback (must be within default range (100, 800) for skipping to trigger) + _current_timestep = 500 + + def _get_cache_config(self): + config_kwargs = self.PAB_CONFIG.copy() + config_kwargs["current_timestep_callback"] = lambda: self._current_timestep + return PyramidAttentionBroadcastConfig(**config_kwargs) + + def _get_hook_names(self): + return [_PYRAMID_ATTENTION_BROADCAST_HOOK] + + +@is_cache +class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin): + """ + Mixin class for testing PyramidAttentionBroadcast caching on models. + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cache + Use `pytest -m "not cache"` to skip these tests + """ + + @require_cache_mixin + def test_pab_cache_enable_disable_state(self): + self._test_cache_enable_disable_state() + + @require_cache_mixin + def test_pab_cache_double_enable_raises_error(self): + self._test_cache_double_enable_raises_error() + + @require_cache_mixin + def test_pab_cache_hooks_registered(self): + self._test_cache_hooks_registered() + + @require_cache_mixin + def test_pab_cache_inference(self): + self._test_cache_inference() + + @require_cache_mixin + def test_pab_cache_context_manager(self): + self._test_cache_context_manager() + + @require_cache_mixin + def test_pab_reset_stateful_cache(self): + self._test_reset_stateful_cache() + + +@is_cache +class FirstBlockCacheConfigMixin: + """ + Base mixin providing FirstBlockCache config. + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + """ + + # Default FBC config - can be overridden by subclasses + # Higher threshold makes FBC more aggressive about caching (skips more often) + FBC_CONFIG = { + "threshold": 1.0, + } + + def _get_cache_config(self): + return FirstBlockCacheConfig(**self.FBC_CONFIG) + + def _get_hook_names(self): + return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK] + + +@is_cache +class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin): + """ + Mixin class for testing FirstBlockCache on models. 
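+
+    Illustrative usage (hypothetical class names; the config mixin supplies `model_class`,
+    `get_init_dict()` and `get_dummy_inputs()`). `FBC_CONFIG` is forwarded to
+    `FirstBlockCacheConfig`, so a subclass can make caching less aggressive by lowering the
+    threshold:
+
+        class TestMyTransformerFBC(MyTransformerTesterConfig, FirstBlockCacheTesterMixin):
+            FBC_CONFIG = {"threshold": 0.5}
+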
+ + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cache + Use `pytest -m "not cache"` to skip these tests + """ + + @torch.no_grad() + def _test_cache_inference(self): + """Test that model can run inference with FBC cache enabled (requires cache_context).""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + # FBC requires cache_context to be set for inference + with model.cache_context("fbc_test"): + # First pass populates the cache + _ = model(**inputs_dict, return_dict=False)[0] + + # Create modified inputs for second pass + inputs_dict_step2 = inputs_dict.copy() + if self.cache_input_key in inputs_dict_step2: + inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like( + inputs_dict_step2[self.cache_input_key] + ) + + # Second pass - FBC should skip remaining blocks and use cached residuals + output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] + + assert output_with_cache is not None, "Model output should not be None with cache enabled." + assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled." + + # Run same inputs without cache to compare + model.disable_cache() + output_without_cache = model(**inputs_dict_step2, return_dict=False)[0] + + # Cached output should be different from non-cached output (due to approximation) + assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), ( + "Cached output should be different from non-cached output due to cache approximation." + ) + + @torch.no_grad() + def _test_reset_stateful_cache(self): + """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context).""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + with model.cache_context("fbc_test"): + _ = model(**inputs_dict, return_dict=False)[0] + + model._reset_stateful_cache() + + model.disable_cache() + + @require_cache_mixin + def test_fbc_cache_enable_disable_state(self): + self._test_cache_enable_disable_state() + + @require_cache_mixin + def test_fbc_cache_double_enable_raises_error(self): + self._test_cache_double_enable_raises_error() + + @require_cache_mixin + def test_fbc_cache_hooks_registered(self): + self._test_cache_hooks_registered() + + @require_cache_mixin + def test_fbc_cache_inference(self): + self._test_cache_inference() + + @require_cache_mixin + def test_fbc_cache_context_manager(self): + self._test_cache_context_manager() + + @require_cache_mixin + def test_fbc_reset_stateful_cache(self): + self._test_reset_stateful_cache() + + +@is_cache +class FasterCacheConfigMixin: + """ + Base mixin providing FasterCache config. 
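+
+    Illustrative note (a sketch): `_get_cache_config` accepts an optional
+    `current_timestep_callback`, so a test can drive the timestep externally:
+
+        step = [1000]
+        config = self._get_cache_config(current_timestep_callback=lambda: step[0])
+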
+ + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + """ + + # Default FasterCache config - can be overridden by subclasses + FASTER_CACHE_CONFIG = { + "spatial_attention_block_skip_range": 2, + "spatial_attention_timestep_skip_range": (-1, 901), + "tensor_format": "BCHW", + } + + def _get_cache_config(self, current_timestep_callback=None): + config_kwargs = self.FASTER_CACHE_CONFIG.copy() + if current_timestep_callback is None: + current_timestep_callback = lambda: 1000 # noqa: E731 + config_kwargs["current_timestep_callback"] = current_timestep_callback + return FasterCacheConfig(**config_kwargs) + + def _get_hook_names(self): + return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK] + + +@is_cache +class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin): + """ + Mixin class for testing FasterCache on models. + + Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling + and timestep management. Inference tests are skipped at model level - FasterCache should + be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline). + + Expected class attributes: + - model_class: The model class to test (must use CacheMixin) + + Expected methods to be implemented by subclasses: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cache + Use `pytest -m "not cache"` to skip these tests + """ + + @torch.no_grad() + def _test_cache_inference(self): + """Test that model can run inference with FasterCache enabled.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + current_timestep = [1000] + config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0]) + + model.enable_cache(config) + + # First pass with timestep outside skip range - computes and populates cache + current_timestep[0] = 1000 + _ = model(**inputs_dict, return_dict=False)[0] + + # Move timestep inside skip range so subsequent passes use cache + current_timestep[0] = 500 + + # Create modified inputs for second pass + inputs_dict_step2 = inputs_dict.copy() + if self.cache_input_key in inputs_dict_step2: + inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like( + inputs_dict_step2[self.cache_input_key] + ) + + # Second pass uses cached attention with different inputs + output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] + + assert output_with_cache is not None, "Model output should not be None with cache enabled." + assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled." + + # Run same inputs without cache to compare + model.disable_cache() + output_without_cache = model(**inputs_dict_step2, return_dict=False)[0] + + # Cached output should be different from non-cached output (due to approximation) + assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), ( + "Cached output should be different from non-cached output due to cache approximation." 
+ ) + + @torch.no_grad() + def _test_reset_stateful_cache(self): + """Test that _reset_stateful_cache resets the FasterCache state.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + config = self._get_cache_config() + model.enable_cache(config) + + _ = model(**inputs_dict, return_dict=False)[0] + + model._reset_stateful_cache() + + model.disable_cache() + + @require_cache_mixin + def test_faster_cache_enable_disable_state(self): + self._test_cache_enable_disable_state() + + @require_cache_mixin + def test_faster_cache_double_enable_raises_error(self): + self._test_cache_double_enable_raises_error() + + @require_cache_mixin + def test_faster_cache_hooks_registered(self): + self._test_cache_hooks_registered() + + @require_cache_mixin + def test_faster_cache_inference(self): + self._test_cache_inference() + + @require_cache_mixin + def test_faster_cache_context_manager(self): + self._test_cache_context_manager() + + @require_cache_mixin + def test_faster_cache_reset_stateful_cache(self): + self._test_reset_stateful_cache() diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py new file mode 100644 index 000000000000..145a6fc27f35 --- /dev/null +++ b/tests/models/testing_utils/common.py @@ -0,0 +1,666 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections import defaultdict +from typing import Any, Dict, Optional, Type + +import pytest +import torch +import torch.nn as nn +from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size + +from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging +from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator + +from ...testing_utils import assert_tensors_close, torch_device + + +def named_persistent_module_tensors( + module: nn.Module, + recurse: bool = False, +): + """ + A helper function that gathers all the tensors (parameters + persistent buffers) of a given module. + + Args: + module (`torch.nn.Module`): + The module we want the tensors on. + recurse (`bool`, *optional`, defaults to `False`): + Whether or not to go look in every submodule or just return the direct parameters and buffers. + """ + yield from module.named_parameters(recurse=recurse) + + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + # Get parent by splitting on dots and traversing the model + parent = module + if "." 
in name: + parent_name = name.rsplit(".", 1)[0] + for part in parent_name.split("."): + parent = getattr(parent, part) + name = name.split(".")[-1] + if name not in parent._non_persistent_buffers_set: + yield named_buffer + + +def compute_module_persistent_sizes( + model: nn.Module, + dtype: str | torch.device | None = None, + special_dtypes: dict[str, str | torch.device] | None = None, +): + """ + Compute the size of each submodule of a given model (parameters + persistent buffers). + """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + module_list = named_persistent_module_tensors(model, recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + +def calculate_expected_num_shards(index_map_path): + """ + Calculate expected number of shards from index file. + + Args: + index_map_path: Path to the sharded checkpoint index file + + Returns: + int: Expected number of shards + """ + with open(index_map_path) as f: + weight_map_dict = json.load(f)["weight_map"] + first_key = list(weight_map_dict.keys())[0] + weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors + expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0]) + return expected_num_shards + + +def check_device_map_is_respected(model, device_map): + for param_name, param in model.named_parameters(): + # Find device in device_map + while len(param_name) > 0 and param_name not in device_map: + param_name = ".".join(param_name.split(".")[:-1]) + if param_name not in device_map: + raise ValueError("device map is incomplete, it does not contain any device for `param_name`.") + + param_device = device_map[param_name] + if param_device in ["cpu", "disk"]: + assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}" + else: + assert param.device == torch.device(param_device), ( + f"Expected device {param_device} for {param_name}, got {param.device}" + ) + + +def cast_inputs_to_dtype(inputs, current_dtype, target_dtype): + if torch.is_tensor(inputs): + return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs + if isinstance(inputs, dict): + return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()} + if isinstance(inputs, list): + return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs] + + return inputs + + +class BaseModelTesterConfig: + """ + Base class defining the configuration interface for model testing. + + This class defines the contract that all model test classes must implement. 
+ It provides a consistent interface for accessing model configuration, initialization + parameters, and test inputs across all testing mixins. + + Required properties (must be implemented by subclasses): + - model_class: The model class to test + + Optional properties (can be overridden, have sensible defaults): + - pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None) + - pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {}) + - output_shape: Expected output shape for output validation tests (default: None) + - model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7]) + + Required methods (must be implemented by subclasses): + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Example usage: + class MyModelTestConfig(BaseModelTesterConfig): + @property + def model_class(self): + return MyModel + + @property + def pretrained_model_name_or_path(self): + return "org/my-model" + + @property + def output_shape(self): + return (1, 3, 32, 32) + + def get_init_dict(self): + return {"in_channels": 3, "out_channels": 3} + + def get_dummy_inputs(self): + return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)} + + class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin): + pass + """ + + # ==================== Required Properties ==================== + + @property + def model_class(self) -> Type[nn.Module]: + """The model class to test. Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `model_class` property.") + + # ==================== Optional Properties ==================== + + @property + def pretrained_model_name_or_path(self) -> Optional[str]: + """Hub repository ID for the pretrained model (used for quantization and hub tests).""" + return None + + @property + def pretrained_model_kwargs(self) -> Dict[str, Any]: + """Additional kwargs to pass to from_pretrained (e.g., subfolder, variant).""" + return {} + + @property + def output_shape(self) -> Optional[tuple]: + """Expected output shape for output validation tests.""" + return None + + @property + def model_split_percents(self) -> list: + """Percentages for model parallelism tests.""" + return [0.9] + + # ==================== Required Methods ==================== + + def get_init_dict(self) -> Dict[str, Any]: + """ + Returns dict of arguments to initialize the model. + + Returns: + Dict[str, Any]: Initialization arguments for the model constructor. + + Example: + return { + "in_channels": 3, + "out_channels": 3, + "sample_size": 32, + } + """ + raise NotImplementedError("Subclasses must implement `get_init_dict()`.") + + def get_dummy_inputs(self) -> Dict[str, Any]: + """ + Returns dict of inputs to pass to the model forward pass. + + Returns: + Dict[str, Any]: Input tensors/values for model.forward(). + + Example: + return { + "sample": torch.randn(1, 3, 32, 32, device=torch_device), + "timestep": torch.tensor([1], device=torch_device), + } + """ + raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.") + + +class ModelTesterMixin: + """ + Base mixin class for model testing with common test methods. 
+ + This mixin expects the test class to also inherit from BaseModelTesterConfig + (or implement its interface) which provides: + - model_class: The model class to test + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Example: + class MyModelTestConfig(BaseModelTesterConfig): + model_class = MyModel + def get_init_dict(self): ... + def get_dummy_inputs(self): ... + + class TestMyModel(MyModelTestConfig, ModelTesterMixin): + pass + """ + + @torch.no_grad() + def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5): + torch.manual_seed(0) + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + model.save_pretrained(tmp_path) + new_model = self.model_class.from_pretrained(tmp_path) + new_model.to(torch_device) + + for param_name in model.state_dict().keys(): + param_1 = model.state_dict()[param_name] + param_2 = new_model.state_dict()[param_name] + assert param_1.shape == param_2.shape, ( + f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}" + ) + + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + + assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + + @torch.no_grad() + def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + model.save_pretrained(tmp_path, variant="fp16") + new_model = self.model_class.from_pretrained(tmp_path, variant="fp16") + + with pytest.raises(OSError) as exc_info: + self.model_class.from_pretrained(tmp_path) + + assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value) + + new_model.to(torch_device) + + image = model(**self.get_dummy_inputs(), return_dict=False)[0] + new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0] + + assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.") + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"]) + def test_from_save_pretrained_dtype(self, tmp_path, dtype): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + if torch_device == "mps" and dtype == torch.bfloat16: + pytest.skip(reason=f"{dtype} is not supported on {torch_device}") + + model.to(dtype) + model.save_pretrained(tmp_path) + new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype) + assert new_model.dtype == dtype + if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None: + # When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None + new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype) + assert new_model.dtype == dtype + + @torch.no_grad() + def test_determinism(self, atol=1e-5, rtol=0): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + first = model(**self.get_dummy_inputs(), return_dict=False)[0] + second = model(**self.get_dummy_inputs(), return_dict=False)[0] + + first_flat = first.flatten() + second_flat = second.flatten() + mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat)) + 
first_filtered = first_flat[mask] + second_filtered = second_flat[mask] + + assert_tensors_close( + first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic" + ) + + @torch.no_grad() + def test_output(self, expected_output_shape=None): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + inputs_dict = self.get_dummy_inputs() + output = model(**inputs_dict, return_dict=False)[0] + + assert output is not None, "Model output is None" + assert output[0].shape == expected_output_shape or self.output_shape, ( + f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}" + ) + + @torch.no_grad() + def test_outputs_equivalence(self, atol=1e-5, rtol=0): + def set_nan_tensor_to_zero(t): + device = t.device + if device.type == "mps": + t = t.to("cpu") + t[t != t] = 0 + return t.to(device) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (list, tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + assert_tensors_close( + set_nan_tensor_to_zero(tuple_object), + set_nan_tensor_to_zero(dict_object), + atol=atol, + rtol=rtol, + msg="Tuple and dict output are not equal", + ) + + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + model.eval() + + outputs_dict = model(**self.get_dummy_inputs()) + outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False) + + recursive_check(outputs_tuple, outputs_dict) + + def test_getattr_is_correct(self, caplog): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + # save some things to test + model.dummy_attribute = 5 + model.register_to_config(test_attribute=5) + + logger_name = "diffusers.models.modeling_utils" + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() + assert hasattr(model, "dummy_attribute") + assert getattr(model, "dummy_attribute") == 5 + assert model.dummy_attribute == 5 + + # no warning should be thrown + assert caplog.text == "" + + with caplog.at_level(logging.WARNING, logger=logger_name): + caplog.clear() + assert hasattr(model, "save_pretrained") + fn = model.save_pretrained + fn_1 = getattr(model, "save_pretrained") + + assert fn == fn_1 + + # no warning should be thrown + assert caplog.text == "" + + # warning should be thrown for config attributes accessed directly + with pytest.warns(FutureWarning): + assert model.test_attribute == 5 + + with pytest.warns(FutureWarning): + assert getattr(model, "test_attribute") == 5 + + with pytest.raises(AttributeError) as error: + model.does_not_exist + + assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'" + + @require_accelerator + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be used with an accelerator", + ) + def test_keep_in_fp32_modules(self): + model = self.model_class(**self.get_init_dict()) + fp32_modules = model._keep_in_fp32_modules + + if fp32_modules is None or len(fp32_modules) == 0: + pytest.skip("Model does not have _keep_in_fp32_modules defined.") + + # Test with float16 + model.to(torch_device) + 
model.to(torch.float16) + + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}" + else: + assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}" + + @require_accelerator + @pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], + reason="float16 and bfloat16 can only be use for inference with an accelerator", + ) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) + @torch.no_grad() + def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): + model = self.model_class(**self.get_init_dict()) + model.to(torch_device) + fp32_modules = model._keep_in_fp32_modules or [] + + model.to(dtype).save_pretrained(tmp_path) + model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) + + for name, param in model_loaded.named_parameters(): + if fp32_modules and any( + module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules + ): + assert param.data.dtype == torch.float32 + else: + assert param.data.dtype == dtype + + inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype) + output = model(**inputs, return_dict=False)[0] + output_loaded = model_loaded(**inputs, return_dict=False)[0] + + self._check_dtype_inference_output(output, output_loaded, dtype) + + def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0): + """Check dtype inference output with configurable tolerance.""" + assert_tensors_close( + output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}" + ) + + @require_accelerator + @torch.no_grad() + def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0): + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + base_output = model(**inputs_dict, return_dict=False)[0] + + model_size = compute_module_persistent_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small + + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" + + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) + + new_model = self.model_class.from_pretrained(tmp_path).eval() + new_model = new_model.to(torch_device) + + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new, return_dict=False)[0] + + assert_tensors_close( + base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load" + ) + + @require_accelerator + @torch.no_grad() + def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0): + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + base_output = 
model(**inputs_dict, return_dict=False)[0] + + model_size = compute_module_persistent_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small + variant = "fp16" + + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant) + + index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + assert os.path.exists(os.path.join(tmp_path, index_filename)), ( + f"Variant index file {index_filename} should exist" + ) + + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) + + new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval() + new_model = new_model.to(torch_device) + + torch.manual_seed(0) + inputs_dict_new = self.get_dummy_inputs() + new_output = new_model(**inputs_dict_new, return_dict=False)[0] + + assert_tensors_close( + base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load" + ) + + @torch.no_grad() + def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0): + from diffusers.utils import constants + + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + model = model.to(torch_device) + + base_output = model(**inputs_dict, return_dict=False)[0] + + model_size = compute_module_persistent_sizes(model)[""] + max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small + + # Save original values to restore after test + original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING + original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None) + + try: + model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB") + assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist" + + # Check if the right number of shards exists + expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)) + actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")]) + assert actual_num_shards == expected_num_shards, ( + f"Expected {expected_num_shards} shards, got {actual_num_shards}" + ) + + # Load without parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = False + model_sequential = self.model_class.from_pretrained(tmp_path).eval() + model_sequential = model_sequential.to(torch_device) + + # Load with parallel loading + constants.HF_ENABLE_PARALLEL_LOADING = True + constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2 + + torch.manual_seed(0) + model_parallel = self.model_class.from_pretrained(tmp_path).eval() + model_parallel = model_parallel.to(torch_device) + + torch.manual_seed(0) + inputs_dict_parallel = self.get_dummy_inputs() + output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0] + + assert_tensors_close( + base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading" + ) + + finally: + # Restore original values + constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading + if original_parallel_workers is not None: + constants.HF_PARALLEL_WORKERS = 
original_parallel_workers + + @require_torch_multi_accelerator + @torch.no_grad() + def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0): + if self.model_class._no_split_modules is None: + pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict, return_dict=False)[0] + + model_size = compute_module_sizes(model)[""] + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + + model.cpu().save_pretrained(tmp_path) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory) + # Making sure part of the model will be on GPU 0 and GPU 1 + assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs" + + check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict, return_dict=False)[0] + + assert_tensors_close( + base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism" + ) diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py new file mode 100644 index 000000000000..950d4d5d1fa5 --- /dev/null +++ b/tests/models/testing_utils/compile.py @@ -0,0 +1,166 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os + +import pytest +import torch + +from ...testing_utils import ( + backend_empty_cache, + is_torch_compile, + require_accelerator, + require_torch_version_greater, + torch_device, +) + + +@is_torch_compile +@require_accelerator +@require_torch_version_greater("2.7.1") +class TorchCompileTesterMixin: + """ + Mixin class for testing torch.compile functionality on models. 
+ + Expected from config mixin: + - model_class: The model class to test + + Optional properties: + - different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None) + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: compile + Use `pytest -m "not compile"` to skip these tests + """ + + @property + def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None: + """Optional list of (height, width) tuples for dynamic shape testing.""" + return None + + def setup_method(self): + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + @torch.no_grad() + def test_torch_compile_recompilation_and_graph_break(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True) + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + @torch.no_grad() + def test_torch_compile_repeated_blocks(self): + if self.model_class._repeated_blocks is None: + pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.") + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model.compile_repeated_blocks(fullgraph=True) + + recompile_limit = 1 + if self.model_class.__name__ == "UNet2DConditionModel": + recompile_limit = 2 + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(recompile_limit=recompile_limit), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + @torch.no_grad() + def test_compile_with_group_offloading(self): + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + + torch._dynamo.config.cache_size_limit = 10000 + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.eval() + + group_offload_kwargs = { + "onload_device": torch_device, + "offload_device": "cpu", + "offload_type": "block_level", + "num_blocks_per_group": 1, + "use_stream": True, + "non_blocking": True, + } + model.enable_group_offload(**group_offload_kwargs) + model.compile() + + _ = model(**inputs_dict) + _ = model(**inputs_dict) + + @torch.no_grad() + def test_compile_on_different_shapes(self): + if self.different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + torch.fx.experimental._config.use_duck_shape = False + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model = torch.compile(model, fullgraph=True, dynamic=True) + + for height, width in self.different_shapes_for_compilation: + with torch._dynamo.config.patch(error_on_recompile=True): + inputs_dict = self.get_dummy_inputs(height=height, width=width) + _ = model(**inputs_dict) + + @torch.no_grad() + def test_compile_works_with_aot(self, tmp_path): + from torch._inductor.package import load_package + + init_dict = self.get_init_dict() + inputs_dict = 
self.get_dummy_inputs() + + model = self.model_class(**init_dict).to(torch_device) + exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) + + package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2") + _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) + assert os.path.exists(package_path), f"Package file not created at {package_path}" + loaded_binary = load_package(package_path, run_single_threaded=True) + + model.forward = loaded_binary + + _ = model(**inputs_dict) + _ = model(**inputs_dict) diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py new file mode 100644 index 000000000000..632019c87499 --- /dev/null +++ b/tests/models/testing_utils/ip_adapter.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import pytest +import torch + +from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device + + +def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool: + """ + Check if IP Adapter processors are correctly set in the model. + + Args: + model: The model to check + + Returns: + bool: True if IP Adapter is correctly set, False otherwise + """ + for module in model.attn_processors.values(): + if isinstance(module, processor_cls): + return True + return False + + +@is_ip_adapter +class IPAdapterTesterMixin: + """ + Mixin class for testing IP Adapter functionality on models. + + Expected from config mixin: + - model_class: The model class to test + + Required properties (must be implemented by subclasses): + - ip_adapter_processor_cls: The IP Adapter processor class to use + + Required methods (must be implemented by subclasses): + - create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing + - modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: ip_adapter + Use `pytest -m "not ip_adapter"` to skip these tests + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + @property + def ip_adapter_processor_cls(self): + """IP Adapter processor class to use for testing. 
Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.") + + def create_ip_adapter_state_dict(self, model): + raise NotImplementedError("child class must implement method to create IPAdapter State Dict") + + def modify_inputs_for_ip_adapter(self, model, inputs_dict): + raise NotImplementedError("child class must implement method to create IPAdapter model inputs") + + @torch.no_grad() + def test_load_ip_adapter(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + torch.manual_seed(0) + output_no_adapter = model(**inputs_dict, return_dict=False)[0] + + ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) + + model._load_ip_adapter_weights([ip_adapter_state_dict]) + assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), ( + "IP Adapter processors not set correctly" + ) + + inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy()) + outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0] + + assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), ( + "Output should differ with IP Adapter enabled" + ) + + @pytest.mark.skip( + reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring" + ) + def test_ip_adapter_scale(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) + model._load_ip_adapter_weights([ip_adapter_state_dict]) + + inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy()) + + # Test scale = 0.0 (no effect) + model.set_ip_adapter_scale(0.0) + torch.manual_seed(0) + output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0] + + # Test scale = 1.0 (full effect) + model.set_ip_adapter_scale(1.0) + torch.manual_seed(0) + output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0] + + # Outputs should differ with different scales + assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), ( + "Output should differ with different IP Adapter scales" + ) + + @pytest.mark.skip( + reason="Unloading IP Adapter is not defined at the model level. 
Enable this test after refactoring" + ) + def test_unload_ip_adapter(self): + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + # Save original processors + original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} + + # Create and load IP adapter + ip_adapter_state_dict = self.create_ip_adapter_state_dict(model) + model._load_ip_adapter_weights([ip_adapter_state_dict]) + + assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set" + + # Unload IP adapter + model.unload_ip_adapter() + + assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), ( + "IP Adapter should be unloaded" + ) + + # Verify processors are restored + current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()} + assert original_processors == current_processors, "Processors should be restored after unload" diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py new file mode 100644 index 000000000000..994aaed55ca7 --- /dev/null +++ b/tests/models/testing_utils/lora.py @@ -0,0 +1,555 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import os +import re + +import pytest +import safetensors.torch +import torch +import torch.nn as nn + +from diffusers.utils.import_utils import is_peft_available +from diffusers.utils.testing_utils import check_if_dicts_are_equal + +from ...testing_utils import ( + assert_tensors_close, + backend_empty_cache, + is_lora, + is_torch_compile, + require_peft_backend, + require_peft_version_greater, + require_torch_accelerator, + require_torch_version_greater, + torch_device, +) + + +if is_peft_available(): + from diffusers.loaders.peft import PeftAdapterMixin + + +def check_if_lora_correctly_set(model) -> bool: + """ + Check if LoRA layers are correctly set in the model. + + Args: + model: The model to check + + Returns: + bool: True if LoRA is correctly set, False otherwise + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return True + return False + + +@is_lora +@require_peft_backend +class LoraTesterMixin: + """ + Mixin class for testing LoRA/PEFT functionality on models. 
+ + Expected from config mixin: + - model_class: The model class to test + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: lora + Use `pytest -m "not lora"` to skip these tests + """ + + def setup_method(self): + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") + + @torch.no_grad() + def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4): + from peft import LoraConfig + from peft.utils import get_peft_model_state_dict + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + torch.manual_seed(0) + output_no_lora = model(**inputs_dict, return_dict=False)[0] + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + model.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + torch.manual_seed(0) + outputs_with_lora = model(**inputs_dict, return_dict=False)[0] + + assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), ( + "Output should differ with LoRA enabled" + ) + + model.save_lora_adapter(tmp_path) + assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), ( + "LoRA weights file not created" + ) + + state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")) + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") + + for k in state_dict_loaded: + loaded_v = state_dict_loaded[k] + retrieved_v = state_dict_retrieved[k].to(loaded_v.device) + assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}") + + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload" + + torch.manual_seed(0) + outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0] + + assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), ( + "Output should differ with LoRA enabled" + ) + assert_tensors_close( + outputs_with_lora, + outputs_with_lora_2, + atol=atol, + rtol=rtol, + msg="Outputs should match before and after save/load", + ) + + def test_lora_wrong_adapter_name_raises_error(self, tmp_path): + from peft import LoraConfig + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + wrong_name = "foo" + with pytest.raises(ValueError) as exc_info: + model.save_lora_adapter(tmp_path, adapter_name=wrong_name) + + assert f"Adapter name {wrong_name} not found in the model." 
in str(exc_info.value) + + def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False): + from peft import LoraConfig + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=use_dora, + ) + model.add_adapter(denoiser_lora_config) + metadata = model.peft_config["default"].to_dict() + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + model.save_lora_adapter(tmp_path) + model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + parsed_metadata = model.peft_config["default_0"].to_dict() + check_if_dicts_are_equal(metadata, parsed_metadata) + + def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path): + from peft import LoraConfig + + from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + denoiser_lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + model.add_adapter(denoiser_lora_config) + assert check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + model.save_lora_adapter(tmp_path) + model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors") + assert os.path.isfile(model_file), "LoRA weights file not created" + + # Perturb the metadata in the state dict + loaded_state_dict = safetensors.torch.load_file(model_file) + metadata = {"format": "pt"} + lora_adapter_metadata = denoiser_lora_config.to_dict() + lora_adapter_metadata.update({"foo": 1, "bar": 2}) + for key, value in lora_adapter_metadata.items(): + if isinstance(value, set): + lora_adapter_metadata[key] = list(value) + metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True) + safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata) + + model.unload_lora() + assert not check_if_lora_correctly_set(model), "LoRA should be unloaded" + + with pytest.raises(TypeError) as exc_info: + model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True) + assert "`LoraConfig` class could not be instantiated" in str(exc_info.value) + + +@is_lora +@is_torch_compile +@require_peft_backend +@require_peft_version_greater("0.14.0") +@require_torch_version_greater("2.7.1") +@require_torch_accelerator +class LoraHotSwappingForModelTesterMixin: + """ + Mixin class for testing LoRA hot swapping functionality on models. + + Test that hotswapping does not result in recompilation on the model directly. + We're not extensively testing the hotswapping functionality since it is implemented in PEFT + and is extensively tested there. The goal of this test is specifically to ensure that + hotswapping with diffusers does not require recompilation. + + See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252 + for the analogous PEFT test. 
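+
+    "Hotswapping" here means replacing the weights of an already-loaded adapter in place via
+    `load_lora_adapter(..., hotswap=True)`, so switching adapters on a compiled model should
+    not require recompilation.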
+ + Expected from config mixin: + - model_class: The model class to test + + Optional properties: + - different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None) + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest marks: lora, torch_compile + Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests + """ + + @property + def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None: + """Optional list of (height, width) tuples for dynamic compilation tests.""" + return None + + def setup_method(self): + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).") + + def teardown_method(self): + # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, + # there will be recompilation errors, as torch caches the model when run in the same process. + torch.compiler.reset() + gc.collect() + backend_empty_cache(torch_device) + + def _get_lora_config(self, lora_rank, lora_alpha, target_modules): + from peft import LoraConfig + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules, + init_lora_weights=False, + use_dora=False, + ) + return lora_config + + def _get_linear_module_name_other_than_attn(self, model): + linear_names = [ + name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name + ] + return linear_names[0] + + def _check_model_hotswap( + self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3 + ): + """ + Check that hotswapping works on a model. + + Steps: + - create 2 LoRA adapters and save them + - load the first adapter + - hotswap the second adapter + - check that the outputs are correct + - optionally compile the model + - optionally check if recompilations happen on different shapes + + Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would + fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is + fine. 
+ """ + different_shapes = self.different_shapes_for_compilation + # create 2 adapters with different ranks and alphas + torch.manual_seed(0) + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict).to(torch_device) + + alpha0, alpha1 = rank0, rank1 + max_rank = max([rank0, rank1]) + if target_modules1 is None: + target_modules1 = target_modules0[:] + lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1) + + model.add_adapter(lora_config0, adapter_name="adapter0") + with torch.inference_mode(): + torch.manual_seed(0) + output0_before = model(**inputs_dict)["sample"] + + model.add_adapter(lora_config1, adapter_name="adapter1") + model.set_adapter("adapter1") + with torch.inference_mode(): + torch.manual_seed(0) + output1_before = model(**inputs_dict)["sample"] + + # sanity checks: + assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol) + assert not (output0_before == 0).all() + assert not (output1_before == 0).all() + + # save the adapter checkpoints + model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0") + model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1") + del model + + # load the first adapter + torch.manual_seed(0) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + model.enable_lora_hotswap(target_rank=max_rank) + + file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors") + model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) + + if do_compile: + model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None) + + with torch.inference_mode(): + # additionally check if dynamic compilation works. 
+ if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output0_after = model(**inputs_dict)["sample"] + assert_tensors_close( + output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0" + ) + + # hotswap the 2nd adapter + model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) + + # we need to call forward to potentially trigger recompilation + with torch.inference_mode(): + if different_shapes is not None: + for height, width in different_shapes: + new_inputs_dict = self.prepare_dummy_input(height=height, width=width) + _ = model(**new_inputs_dict) + else: + output1_after = model(**inputs_dict)["sample"] + assert_tensors_close( + output1_before, + output1_after, + atol=atol, + rtol=rtol, + msg="Output mismatch after hotswapping to adapter1", + ) + + # check error when not passing valid adapter name + name = "does-not-exist" + msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" + with pytest.raises(ValueError, match=re.escape(msg)): + model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_model(self, tmp_path, rank0, rank1): + self._check_model_hotswap( + tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"] + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1): + if "unet" not in self.model_class.__name__.lower(): + pytest.skip("Test only applies to UNet.") + + # It's important to add this context to raise an error on recompilation + target_modules = ["conv", "conv1", "conv2"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1): + if "unet" not in self.model_class.__name__.lower(): + pytest.skip("Test only applies to UNet.") + + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "conv"] + with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache(): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + @pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1): + # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping + # with `torch.compile()` for 
models that have both linear and conv layers. In this test, we check + # if we can target a linear layer from the transformer blocks and another linear layer from non-attention + # block. + target_modules = ["to_q"] + init_dict = self.get_init_dict() + model = self.model_class(**init_dict) + + target_modules.append(self._get_linear_module_name_other_than_attn(model)) + del model + + # It's important to add this context to raise an error on recompilation + with torch._dynamo.config.patch(error_on_recompile=True): + self._check_model_hotswap( + tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules + ) + + def test_enable_lora_hotswap_called_after_adapter_added_raises(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + + msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") + with pytest.raises(RuntimeError, match=msg): + model.enable_lora_hotswap(target_rank=32) + + def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog): + # ensure that enable_lora_hotswap is called before loading the first adapter + import logging + + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = ( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="warn") + assert any(msg in record.message for record in caplog.records) + + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog): + # check possibility to ignore the error/warning + import logging + + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + with caplog.at_level(logging.WARNING): + model.enable_lora_hotswap(target_rank=32, check_compiled="ignore") + assert len(caplog.records) == 0 + + def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): + # check that wrong argument value raises an error + lora_config = self._get_lora_config(8, 8, target_modules=["to_q"]) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) + model.add_adapter(lora_config) + msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") + with pytest.raises(ValueError, match=msg): + model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") + + def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog): + # check the error and log + import logging + + # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers + target_modules0 = ["to_q"] + target_modules1 = ["to_q", "to_k"] + with pytest.raises(RuntimeError): # peft raises RuntimeError + with caplog.at_level(logging.ERROR): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=8, + rank1=8, + target_modules0=target_modules0, + target_modules1=target_modules1, + ) + assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records) + + 
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)]) + @require_torch_version_greater("2.7.1") + def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1): + different_shapes_for_compilation = self.different_shapes_for_compilation + if different_shapes_for_compilation is None: + pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") + # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic + # variable to represent input sizes that are the same. For more details, + # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790). + torch.fx.experimental._config.use_duck_shape = False + + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self._check_model_hotswap( + tmp_path, + do_compile=True, + rank0=rank0, + rank1=rank1, + target_modules0=target_modules, + ) diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py new file mode 100644 index 000000000000..68480bddd39c --- /dev/null +++ b/tests/models/testing_utils/memory.py @@ -0,0 +1,498 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import glob +import inspect +from functools import wraps + +import pytest +import torch +from accelerate.utils.modeling import compute_module_sizes + +from diffusers.utils.testing_utils import _check_safetensors_serialization +from diffusers.utils.torch_utils import get_torch_cuda_device_capability + +from ...testing_utils import ( + assert_tensors_close, + backend_empty_cache, + backend_max_memory_allocated, + backend_reset_peak_memory_stats, + backend_synchronize, + is_cpu_offload, + is_group_offload, + is_memory, + require_accelerator, + torch_device, +) +from .common import cast_inputs_to_dtype, check_device_map_is_respected + + +def require_offload_support(func): + """ + Decorator to skip tests if model doesn't support offloading (requires _no_split_modules). + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.model_class._no_split_modules is None: + pytest.skip("Test not supported for this model as `_no_split_modules` is not set.") + return func(self, *args, **kwargs) + + return wrapper + + +def require_group_offload_support(func): + """ + Decorator to skip tests if model doesn't support group offloading. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.model_class._supports_group_offloading: + pytest.skip("Model does not support group offloading.") + return func(self, *args, **kwargs) + + return wrapper + + +@is_cpu_offload +class CPUOffloadTesterMixin: + """ + Mixin class for testing CPU offloading functionality. 
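+
+    These tests save the model, reload it with `device_map="auto"` under a constrained
+    `max_memory` budget so that part of the model ends up on the CPU (or on disk), and then
+    check that the outputs still match those of the fully on-accelerator model.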
+ + Expected from config mixin: + - model_class: The model class to test + + Optional properties: + - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7]) + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: cpu_offload + Use `pytest -m "not cpu_offload"` to skip these tests + """ + + @property + def model_split_percents(self) -> list[float]: + """List of percentages for splitting model across devices during offloading tests.""" + return [0.5, 0.7] + + @require_offload_support + @torch.no_grad() + def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0): + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + model.cpu().save_pretrained(str(tmp_path)) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU" + + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert_tensors_close( + base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading" + ) + + @require_offload_support + @torch.no_grad() + def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0): + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + max_size = int(self.model_split_percents[0] * model_size) + # Force disk offload by setting very small CPU memory + max_memory = {0: max_size, "cpu": int(0.1 * max_size)} + + model.cpu().save_pretrained(str(tmp_path), safe_serialization=False) + # This errors out because it's missing an offload folder + with pytest.raises(ValueError): + new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory) + + new_model = self.model_class.from_pretrained( + str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path) + ) + + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert_tensors_close( + base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading" + ) + + @require_offload_support + @torch.no_grad() + def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0): + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**config).eval() + + model = model.to(torch_device) + + torch.manual_seed(0) + base_output = model(**inputs_dict) + + model_size = compute_module_sizes(model)[""] + model.cpu().save_pretrained(str(tmp_path)) + + max_size = int(self.model_split_percents[0] * 
model_size) + max_memory = {0: max_size, "cpu": max_size} + new_model = self.model_class.from_pretrained( + str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory + ) + + check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict) + + assert_tensors_close( + base_output[0], + new_output[0], + atol=atol, + rtol=rtol, + msg="Output should match with disk offloading (safetensors)", + ) + + +@is_group_offload +class GroupOffloadTesterMixin: + """ + Mixin class for testing group offloading functionality. + + Expected from config mixin: + - model_class: The model class to test + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: group_offload + Use `pytest -m "not group_offload"` to skip these tests + """ + + @require_group_offload_support + @pytest.mark.parametrize("record_stream", [False, True]) + def test_group_offloading(self, record_stream, atol=1e-5, rtol=0): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + torch.manual_seed(0) + + @torch.no_grad() + def run_forward(model): + assert all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") + ), "Group offloading hook should be set" + model.eval() + return model(**inputs_dict)[0] + + model = self.model_class(**init_dict) + + model.to(torch_device) + output_without_group_offloading = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) + output_with_group_offloading1 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) + output_with_group_offloading2 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload(torch_device, offload_type="leaf_level") + output_with_group_offloading3 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.enable_group_offload( + torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream + ) + output_with_group_offloading4 = run_forward(model) + + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading1, + atol=atol, + rtol=rtol, + msg="Output should match with block-level offloading", + ) + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading2, + atol=atol, + rtol=rtol, + msg="Output should match with non-blocking block-level offloading", + ) + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading3, + atol=atol, + rtol=rtol, + msg="Output should match with leaf-level offloading", + ) + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading4, + atol=atol, + rtol=rtol, + msg="Output should match with leaf-level offloading with stream", + ) + + @require_group_offload_support + @pytest.mark.parametrize("record_stream", [False, True]) + @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"]) + @torch.no_grad() + def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type): + torch.manual_seed(0) + 
init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + + model.to(torch_device) + model.eval() + _ = model(**inputs_dict)[0] + + torch.manual_seed(0) + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + storage_dtype, compute_dtype = torch.float16, torch.float32 + inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**init_dict) + model.eval() + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1} + model.enable_group_offload( + torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs + ) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + _ = model(**inputs_dict)[0] + + @require_group_offload_support + @pytest.mark.parametrize("record_stream", [False, True]) + @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"]) + @torch.no_grad() + @torch.inference_mode() + def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0): + def _has_generator_arg(model): + sig = inspect.signature(model.forward) + params = sig.parameters + return "generator" in params + + def _run_forward(model, inputs_dict): + accepts_generator = _has_generator_arg(model) + if accepts_generator: + inputs_dict["generator"] = torch.manual_seed(0) + torch.manual_seed(0) + return model(**inputs_dict)[0] + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + torch.manual_seed(0) + model = self.model_class(**init_dict) + + model.eval() + model.to(torch_device) + output_without_group_offloading = _run_forward(model, inputs_dict) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.eval() + + num_blocks_per_group = None if offload_type == "leaf_level" else 1 + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group} + tmpdir = str(tmp_path) + model.enable_group_offload( + torch_device, + offload_type=offload_type, + offload_to_disk_path=tmpdir, + use_stream=True, + record_stream=record_stream, + **additional_kwargs, + ) + has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") + assert has_safetensors, "No safetensors found in the directory." + + # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic + # in nature. So, skip it. + if offload_type != "leaf_level": + is_correct, extra_files, missing_files = _check_safetensors_serialization( + module=model, + offload_to_disk_path=tmpdir, + offload_type=offload_type, + num_blocks_per_group=num_blocks_per_group, + ) + if not is_correct: + if extra_files: + raise ValueError(f"Found extra files: {', '.join(extra_files)}") + elif missing_files: + raise ValueError(f"Following files are missing: {', '.join(missing_files)}") + + output_with_group_offloading = _run_forward(model, inputs_dict) + assert_tensors_close( + output_without_group_offloading, + output_with_group_offloading, + atol=atol, + rtol=rtol, + msg="Output should match with disk-based group offloading", + ) + + +class LayerwiseCastingTesterMixin: + """ + Mixin class for testing layerwise dtype casting for memory optimization. 
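+
+    Layerwise casting stores weights in a lower-precision dtype (e.g. `torch.float8_e4m3fn`)
+    and casts them up to the compute dtype on the fly for each forward pass, trading a little
+    compute for a smaller memory footprint.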
+ + Expected from config mixin: + - model_class: The model class to test + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + @torch.no_grad() + def test_layerwise_casting_memory(self): + MB_TOLERANCE = 0.2 + LEAST_COMPUTE_CAPABILITY = 8.0 + + def reset_memory_stats(): + gc.collect() + backend_synchronize(torch_device) + backend_empty_cache(torch_device) + backend_reset_peak_memory_stats(torch_device) + + def get_memory_usage(storage_dtype, compute_dtype): + torch.manual_seed(0) + config = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**config).eval() + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + + reset_memory_stats() + model(**inputs_dict) + model_memory_footprint = model.get_memory_footprint() + peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2 + + return model_memory_footprint, peak_inference_memory_allocated_mb + + fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32) + fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32) + fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage( + torch.float8_e4m3fn, torch.bfloat16 + ) + + compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None + assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, ( + "Memory footprint should decrease with lower precision storage" + ) + + # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. + # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it. + if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY: + assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, ( + "Peak memory should be lower with bf16 compute on newer GPUs" + ) + + # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few + # bytes. This only happens for some models, so we allow a small tolerance. + # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32. 
+ assert ( + fp8_e4m3_fp32_max_memory < fp32_max_memory + or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE + ), "Peak memory should be lower or within tolerance with fp8 storage" + + def test_layerwise_casting_training(self): + def test_fn(storage_dtype, compute_dtype): + if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16: + pytest.skip("Skipping test because CPU doesn't go well with bfloat16.") + + model = self.model_class(**self.get_init_dict()) + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + model.train() + + inputs_dict = self.get_dummy_inputs() + inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype) + with torch.amp.autocast(device_type=torch.device(torch_device).type): + output = model(**inputs_dict, return_dict=False)[0] + + input_tensor = inputs_dict[self.main_input_name] + noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device) + noise = cast_inputs_to_dtype(noise, torch.float32, compute_dtype) + loss = torch.nn.functional.mse_loss(output, noise) + + loss.backward() + + test_fn(torch.float16, torch.float32) + test_fn(torch.float8_e4m3fn, torch.float32) + test_fn(torch.float8_e5m2, torch.float32) + test_fn(torch.float8_e4m3fn, torch.bfloat16) + + +@is_memory +@require_accelerator +class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin): + """ + Combined mixin class for all memory optimization tests including CPU/disk offloading, + group offloading, and layerwise dtype casting. + + This mixin inherits from: + - CPUOffloadTesterMixin: CPU and disk offloading tests + - GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level) + - LayerwiseCastingTesterMixin: Layerwise dtype casting tests + + Expected from config mixin: + - model_class: The model class to test + + Optional properties: + - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7]) + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: memory + Use `pytest -m "not memory"` to skip these tests + """ diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py new file mode 100644 index 000000000000..e05b36799e66 --- /dev/null +++ b/tests/models/testing_utils/parallelism.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
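+
+# Helpers and mixins for testing context-parallel inference: each test spawns one worker
+# process per device, initializes a distributed process group, enables parallelism via
+# `ContextParallelConfig`, and runs a forward pass on every rank.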
+ +import os +import socket + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from diffusers.models._modeling_parallel import ContextParallelConfig + +from ...testing_utils import ( + is_context_parallel, + require_torch_multi_accelerator, +) + + +def _find_free_port(): + """Find a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + +def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict): + """Worker function for context parallel testing.""" + try: + # Set up distributed environment + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Initialize process group + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + + # Set device for this process + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + + # Create model + model = model_class(**init_dict) + model.to(device) + model.eval() + + # Move inputs to device + inputs_on_device = {} + for key, value in inputs_dict.items(): + if isinstance(value, torch.Tensor): + inputs_on_device[key] = value.to(device) + else: + inputs_on_device[key] = value + + # Enable context parallelism + cp_config = ContextParallelConfig(**cp_dict) + model.enable_parallelism(config=cp_config) + + # Run forward pass + with torch.no_grad(): + output = model(**inputs_on_device, return_dict=False)[0] + + # Only rank 0 reports results + if rank == 0: + return_dict["status"] = "success" + return_dict["output_shape"] = list(output.shape) + + except Exception as e: + if rank == 0: + return_dict["status"] = "error" + return_dict["error"] = str(e) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +@is_context_parallel +@require_torch_multi_accelerator +class ContextParallelTesterMixin: + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) + def test_context_parallel_inference(self, cp_type): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + # Move all tensors to CPU for multiprocessing + inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + cp_dict = {cp_type: world_size} + + # Find a free port for distributed communication + master_port = _find_free_port() + + # Use multiprocessing manager for cross-process communication + manager = mp.Manager() + return_dict = manager.dict() + + # Spawn worker processes + mp.spawn( + _context_parallel_worker, + args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict), + nprocs=world_size, + join=True, + ) + + assert return_dict.get("status") == "success", ( + f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" + ) diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py new file mode 100644 index 000000000000..f27e912766e5 --- /dev/null +++ b/tests/models/testing_utils/quantization.py @@ -0,0 +1,1368 @@ +# coding=utf-8 +# 
Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc + +import pytest +import torch + +from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig +from diffusers.utils.import_utils import ( + is_bitsandbytes_available, + is_gguf_available, + is_nvidia_modelopt_available, + is_optimum_quanto_available, + is_torchao_available, + is_torchao_version, +) + +from ...testing_utils import ( + backend_empty_cache, + backend_max_memory_allocated, + backend_reset_peak_memory_stats, + is_bitsandbytes, + is_gguf, + is_modelopt, + is_quantization, + is_quanto, + is_torch_compile, + is_torchao, + require_accelerate, + require_accelerator, + require_bitsandbytes_version_greater, + require_gguf_version_greater_or_equal, + require_modelopt_version_greater_or_equal, + require_quanto, + require_torchao_version_greater_or_equal, + torch_device, +) + + +if is_nvidia_modelopt_available(): + import modelopt.torch.quantization as mtq + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +if is_optimum_quanto_available(): + from optimum.quanto import QLinear + +if is_gguf_available(): + pass + +if is_torchao_available(): + if is_torchao_version(">=", "0.9.0"): + pass + + +class LoRALayer(torch.nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only. + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: torch.nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = torch.nn.Sequential( + torch.nn.Linear(module.in_features, rank, bias=False), + torch.nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + torch.nn.init.normal_(self.adapter[0].weight, std=small_std) + torch.nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +@is_quantization +@require_accelerator +class QuantizationTesterMixin: + """ + Base mixin class providing common test implementations for quantization testing. + + Backend-specific mixins should: + 1. Implement _create_quantized_model(config_kwargs) + 2. Implement _verify_if_layer_quantized(name, module, config_kwargs) + 3. Define their config dict (e.g., BNB_CONFIGS, QUANTO_WEIGHT_TYPES, etc.) + 4. 
Use @pytest.mark.parametrize to create tests that call the common test methods below + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods in test classes: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + """ + Create a quantized model with the given config kwargs. + + Args: + config_kwargs: Quantization config parameters + **extra_kwargs: Additional kwargs to pass to from_pretrained (e.g., device_map, offload_folder) + """ + raise NotImplementedError("Subclass must implement _create_quantized_model") + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + raise NotImplementedError("Subclass must implement _verify_if_layer_quantized") + + def _is_module_quantized(self, module): + """ + Check if a module is quantized. Returns True if quantized, False otherwise. + Default implementation tries _verify_if_layer_quantized and catches exceptions. + Subclasses can override for more efficient checking. + """ + try: + self._verify_if_layer_quantized("", module, {}) + return True + except (AssertionError, AttributeError): + return False + + def _load_unquantized_model(self): + kwargs = getattr(self, "pretrained_model_kwargs", {}) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _test_quantization_num_parameters(self, config_kwargs): + model = self._load_unquantized_model() + num_params = model.num_parameters() + + model_quantized = self._create_quantized_model(config_kwargs) + num_params_quantized = model_quantized.num_parameters() + + assert num_params == num_params_quantized, ( + f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}" + ) + + def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): + model = self._load_unquantized_model() + mem = model.get_memory_footprint() + + model_quantized = self._create_quantized_model(config_kwargs) + mem_quantized = model_quantized.get_memory_footprint() + + ratio = mem / mem_quantized + assert ratio >= expected_memory_reduction, ( + f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). 
unquantized={mem}, quantized={mem_quantized}" + ) + + @torch.no_grad() + def _test_quantization_inference(self, config_kwargs): + model_quantized = self._create_quantized_model(config_kwargs) + model_quantized.to(torch_device) + + # Get model dtype from first parameter + model_dtype = next(model_quantized.parameters()).dtype + + inputs = self.get_dummy_inputs() + # Cast inputs to model dtype + inputs = { + k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs.items() + } + output = model_quantized(**inputs, return_dict=False)[0] + + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + def _test_quantization_dtype_assignment(self, config_kwargs): + model = self._create_quantized_model(config_kwargs) + + with pytest.raises(ValueError): + model.to(torch.float16) + + with pytest.raises(ValueError): + device_0 = f"{torch_device}:0" + model.to(device=device_0, dtype=torch.float16) + + with pytest.raises(ValueError): + model.float() + + with pytest.raises(ValueError): + model.half() + + model.to(torch_device) + + @torch.no_grad() + def _test_quantization_lora_inference(self, config_kwargs): + try: + from peft import LoraConfig + except ImportError: + pytest.skip("peft is not available") + + from diffusers.loaders.peft import PeftAdapterMixin + + if not issubclass(self.model_class, PeftAdapterMixin): + pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__})") + + model = self._create_quantized_model(config_kwargs) + + lora_config = LoraConfig( + r=4, + lora_alpha=4, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + ) + model.add_adapter(lora_config) + + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + + assert output is not None, "Model output is None with LoRA" + assert not torch.isnan(output).any(), "Model output contains NaN with LoRA" + + @torch.no_grad() + def _test_quantization_serialization(self, config_kwargs, tmp_path): + model = self._create_quantized_model(config_kwargs) + + model.save_pretrained(str(tmp_path), safe_serialization=True) + + model_loaded = self.model_class.from_pretrained(str(tmp_path)) + + inputs = self.get_dummy_inputs() + output = model_loaded(**inputs, return_dict=False)[0] + assert not torch.isnan(output).any(), "Loaded model output contains NaN" + + def _test_quantized_layers(self, config_kwargs): + model_fp = self._load_unquantized_model() + num_linear_layers = sum(1 for module in model_fp.modules() if isinstance(module, torch.nn.Linear)) + + model_quantized = self._create_quantized_model(config_kwargs) + + num_fp32_modules = 0 + if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules: + for name, module in model_quantized.named_modules(): + if isinstance(module, torch.nn.Linear): + if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules): + num_fp32_modules += 1 + + expected_quantized_layers = num_linear_layers - num_fp32_modules + + num_quantized_layers = 0 + for name, module in model_quantized.named_modules(): + if isinstance(module, torch.nn.Linear): + if hasattr(model_quantized, "_keep_in_fp32_modules") and model_quantized._keep_in_fp32_modules: + if any(fp32_name in name for fp32_name in model_quantized._keep_in_fp32_modules): + continue + self._verify_if_layer_quantized(name, module, config_kwargs) + num_quantized_layers += 1 + + assert num_quantized_layers > 0, ( + f"No quantized 
layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)" + ) + assert num_quantized_layers == expected_quantized_layers, ( + f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})" + ) + + def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert): + """ + Test that modules specified in modules_to_not_convert are not quantized. + + Args: + config_kwargs: Base quantization config kwargs + modules_to_not_convert: List of module names to exclude from quantization + """ + # Create config with modules_to_not_convert + config_kwargs_with_exclusion = config_kwargs.copy() + config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert + + model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion) + + # Find a module that should NOT be quantized + found_excluded = False + for name, module in model_with_exclusion.named_modules(): + if isinstance(module, torch.nn.Linear): + # Check if this module is in the exclusion list + if any(excluded in name for excluded in modules_to_not_convert): + found_excluded = True + # This module should NOT be quantized + assert not self._is_module_quantized(module), ( + f"Module {name} should not be quantized but was found to be quantized" + ) + + assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}" + + # Find a module that SHOULD be quantized (not in exclusion list) + found_quantized = False + for name, module in model_with_exclusion.named_modules(): + if isinstance(module, torch.nn.Linear): + # Check if this module is NOT in the exclusion list + if not any(excluded in name for excluded in modules_to_not_convert): + if self._is_module_quantized(module): + found_quantized = True + break + + assert found_quantized, "No quantized layers found outside of excluded modules" + + # Compare memory footprint with fully quantized model + model_fully_quantized = self._create_quantized_model(config_kwargs) + + mem_with_exclusion = model_with_exclusion.get_memory_footprint() + mem_fully_quantized = model_fully_quantized.get_memory_footprint() + + assert mem_with_exclusion > mem_fully_quantized, ( + f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}" + ) + + @torch.no_grad() + def _test_quantization_device_map(self, config_kwargs): + """ + Test that quantized models work correctly with device_map="auto". + + Args: + config_kwargs: Base quantization config kwargs + """ + model = self._create_quantized_model(config_kwargs, device_map="auto") + + assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute" + assert model.hf_device_map is not None, "hf_device_map should not be None" + + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + @torch.no_grad() + def _test_dequantize(self, config_kwargs): + """ + Test that dequantize() converts quantized model back to standard linear layers. 
+ + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + model.to(torch_device) + + if not hasattr(model, "dequantize"): + pytest.skip("Model does not have dequantize method") + + model.dequantize() + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()" + + # Get model dtype from first parameter + model_dtype = next(model.parameters()).dtype + + inputs = self.get_dummy_inputs() + # Cast inputs to model dtype + inputs = { + k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs.items() + } + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None after dequantization" + assert not torch.isnan(output).any(), "Model output contains NaN after dequantization" + + def _test_quantization_training(self, config_kwargs): + """ + Test that quantized models can be used for training with LoRA-like adapters. + + This test: + 1. Freezes all model parameters + 2. Casts small parameters (e.g., layernorm) to fp32 for stability + 3. Adds LoRA adapters to attention layers + 4. Runs forward and backward passes + 5. Verifies gradients are computed correctly + + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + + # Step 1: freeze all parameters + for param in model.parameters(): + param.requires_grad = False + if param.ndim == 1: + # cast small parameters (e.g. layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters to attention layers + adapter_count = 0 + for _, module in model.named_modules(): + if "Attention" in repr(type(module)): + if hasattr(module, "to_k"): + module.to_k = LoRALayer(module.to_k, rank=4) + adapter_count += 1 + if hasattr(module, "to_q"): + module.to_q = LoRALayer(module.to_q, rank=4) + adapter_count += 1 + if hasattr(module, "to_v"): + module.to_v = LoRALayer(module.to_v, rank=4) + adapter_count += 1 + + if adapter_count == 0: + pytest.skip("No attention layers found in model for adapter training test") + + # Step 3: run forward and backward pass + inputs = self.get_dummy_inputs() + + with torch.amp.autocast(torch_device, dtype=torch.float16): + out = model(**inputs, return_dict=False)[0] + out.norm().backward() + + # Step 4: verify gradients are computed + for module in model.modules(): + if isinstance(module, LoRALayer): + assert module.adapter[1].weight.grad is not None, "LoRA adapter gradient is None" + assert module.adapter[1].weight.grad.norm().item() > 0, "LoRA adapter gradient norm is zero" + + +@is_quantization +@is_bitsandbytes +@require_accelerator +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +class BitsAndBytesConfigMixin: + """ + Base mixin providing BitsAndBytes quantization config and model creation. 
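+
+    A concrete test class would typically subclass `BitsAndBytesTesterMixin` (which combines
+    this config mixin with `QuantizationTesterMixin`) and provide the attributes listed below;
+    the model class and repo id in this sketch are illustrative only:
+
+        class MyTransformerBnBTests(BitsAndBytesTesterMixin):
+            model_class = MyTransformer2DModel
+            pretrained_model_name_or_path = "my-org/my-tiny-model"
+            pretrained_model_kwargs = {"subfolder": "transformer"}
+
+            def get_dummy_inputs(self):
+                ...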
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + """ + + BNB_CONFIGS = { + "4bit_nf4": { + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.float16, + }, + "4bit_fp4": { + "load_in_4bit": True, + "bnb_4bit_quant_type": "fp4", + "bnb_4bit_compute_dtype": torch.float16, + }, + "8bit": { + "load_in_8bit": True, + }, + } + + BNB_EXPECTED_MEMORY_REDUCTIONS = { + "4bit_nf4": 3.0, + "4bit_fp4": 3.0, + "8bit": 1.5, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = BitsAndBytesConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params + assert module.weight.__class__ == expected_weight_class, ( + f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}" + ) + + +@is_bitsandbytes +@require_accelerator +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing BitsAndBytes quantization on models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - BNB_CONFIGS: Dict of config name -> BitsAndBytesConfig kwargs to test + + Pytest mark: bitsandbytes + Use `pytest -m "not bitsandbytes"` to skip these tests + """ + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantization_num_parameters(self, config_name): + self._test_quantization_num_parameters(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantization_memory_footprint(self, config_name): + expected = BitsAndBytesConfigMixin.BNB_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) + self._test_quantization_memory_footprint( + BitsAndBytesConfigMixin.BNB_CONFIGS[config_name], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantization_inference(self, config_name): + self._test_quantization_inference(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_quantization_dtype_assignment(self, config_name): + self._test_quantization_dtype_assignment(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) 
+ def test_bnb_quantization_lora_inference(self, config_name): + self._test_quantization_lora_inference(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_quantization_serialization(self, config_name, tmp_path): + self._test_quantization_serialization(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name], tmp_path) + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantized_layers(self, config_name): + self._test_quantized_layers(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize( + "config_name", + list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()), + ) + def test_bnb_quantization_config_serialization(self, config_name): + model = self._create_quantized_model(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + assert "quantization_config" in model.config, "Missing quantization_config" + _ = model.config["quantization_config"].to_dict() + _ = model.config["quantization_config"].to_diff_dict() + _ = model.config["quantization_config"].to_json_string() + + def test_bnb_original_dtype(self): + config_name = list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys())[0] + config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS[config_name] + + model = self._create_quantized_model(config_kwargs) + + assert "_pre_quantization_dtype" in model.config, "Missing _pre_quantization_dtype" + assert model.config["_pre_quantization_dtype"] in [ + torch.float16, + torch.float32, + torch.bfloat16, + ], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}" + + @torch.no_grad() + def test_bnb_keep_modules_in_fp32(self): + if not hasattr(self.model_class, "_keep_in_fp32_modules"): + pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules") + + config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"] + + original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None) + self.model_class._keep_in_fp32_modules = ["proj_out"] + + try: + model = self._create_quantized_model(config_kwargs) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules): + assert module.weight.dtype == torch.float32, ( + f"Module {name} should be FP32 but is {module.weight.dtype}" + ) + else: + assert module.weight.dtype == torch.uint8, ( + f"Module {name} should be uint8 but is {module.weight.dtype}" + ) + + inputs = self.get_dummy_inputs() + _ = model(**inputs) + finally: + if original_fp32_modules is not None: + self.model_class._keep_in_fp32_modules = original_fp32_modules + + def test_bnb_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert( + BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude + ) + + @pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"]) + def test_bnb_device_map(self, config_name): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + def 
test_bnb_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]) + + def test_bnb_training(self): + """Test that quantized models can be used for training with adapters.""" + self._test_quantization_training(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]) + + +@is_quantization +@is_quanto +@require_quanto +@require_accelerate +@require_accelerator +class QuantoConfigMixin: + """ + Base mixin providing Quanto quantization config and model creation. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + """ + + QUANTO_WEIGHT_TYPES = { + "float8": {"weights_dtype": "float8"}, + "int8": {"weights_dtype": "int8"}, + "int4": {"weights_dtype": "int4"}, + "int2": {"weights_dtype": "int2"}, + } + + QUANTO_EXPECTED_MEMORY_REDUCTIONS = { + "float8": 1.5, + "int8": 1.5, + "int4": 3.0, + "int2": 7.0, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = QuantoConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + assert isinstance(module, QLinear), f"Layer {name} is not QLinear, got {type(module)}" + + def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2): + """Override to use max_memory_allocated for Quanto (get_memory_footprint doesn't reflect quantized _data).""" + # Measure unquantized model memory + backend_reset_peak_memory_stats(torch_device) + backend_empty_cache(torch_device) + + model = self._load_unquantized_model() + model.to(torch_device) + mem = backend_max_memory_allocated(torch_device) + + del model + gc.collect() + backend_empty_cache(torch_device) + + # Measure quantized model memory + backend_reset_peak_memory_stats(torch_device) + + model_quantized = self._create_quantized_model(config_kwargs) + model_quantized.to(torch_device) + mem_quantized = backend_max_memory_allocated(torch_device) + + ratio = mem / mem_quantized + assert ratio >= expected_memory_reduction, ( + f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}" + ) + + +@is_quanto +@require_quanto +@require_accelerate +@require_accelerator +class QuantoTesterMixin(QuantoConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing Quanto quantization on models. 
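+
+    Usage sketch (``MyModelTesterConfig`` is an illustrative placeholder for a
+    concrete config class providing the attributes listed below, in the same way
+    the Flux transformer config class is used elsewhere in this PR):
+
+        class TestMyModelQuanto(MyModelTesterConfig, QuantoTesterMixin):
+            pass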
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - QUANTO_WEIGHT_TYPES: Dict of weight_type_name -> qtype + + Pytest mark: quanto + Use `pytest -m "not quanto"` to skip these tests + """ + + @pytest.mark.parametrize( + "weight_type_name", + list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ) + def test_quanto_quantization_num_parameters(self, weight_type_name): + self._test_quantization_num_parameters(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize( + "weight_type_name", + list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ) + def test_quanto_quantization_memory_footprint(self, weight_type_name): + expected = QuantoConfigMixin.QUANTO_EXPECTED_MEMORY_REDUCTIONS.get(weight_type_name, 1.2) + self._test_quantization_memory_footprint( + QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize( + "weight_type_name", + list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ids=list(QuantoConfigMixin.QUANTO_WEIGHT_TYPES.keys()), + ) + def test_quanto_quantization_inference(self, weight_type_name): + self._test_quantization_inference(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_quantized_layers(self, weight_type_name): + self._test_quantized_layers(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_quantization_lora_inference(self, weight_type_name): + self._test_quantization_lora_inference(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_quantization_serialization(self, weight_type_name, tmp_path): + self._test_quantization_serialization(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name], tmp_path) + + def test_quanto_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert( + QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"], modules_to_exclude + ) + + def test_quanto_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"]) + + def test_quanto_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(QuantoConfigMixin.QUANTO_WEIGHT_TYPES["int8"]) + + +@is_quantization +@is_torchao +@require_accelerator +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoConfigMixin: + """ + Base mixin providing TorchAO quantization config and model creation. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + """ + + TORCHAO_QUANT_TYPES = { + "int4wo": {"quant_type": "int4_weight_only"}, + "int8wo": {"quant_type": "int8_weight_only"}, + "int8dq": {"quant_type": "int8_dynamic_activation_int8_weight"}, + } + + TORCHAO_EXPECTED_MEMORY_REDUCTIONS = { + "int4wo": 1.8, + "int8wo": 1.5, + "int8dq": 1.5, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = TorchAoConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs["device_map"] = str(torch_device) + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}" + + +# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack) +_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA") + + +@is_torchao +@require_accelerator +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing TorchAO quantization on models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - TORCHAO_QUANT_TYPES: Dict of quantization type strings to test + + Pytest mark: torchao + Use `pytest -m "not torchao"` to skip these tests + """ + + @pytest.mark.parametrize( + "quant_type", + [ + pytest.param("int4wo", marks=_int4wo_skip), + "int8wo", + "int8dq", + ], + ids=["int4wo", "int8wo", "int8dq"], + ) + def test_torchao_quantization_num_parameters(self, quant_type): + self._test_quantization_num_parameters(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize( + "quant_type", + [ + pytest.param("int4wo", marks=_int4wo_skip), + "int8wo", + "int8dq", + ], + ids=["int4wo", "int8wo", "int8dq"], + ) + def test_torchao_quantization_memory_footprint(self, quant_type): + expected = TorchAoConfigMixin.TORCHAO_EXPECTED_MEMORY_REDUCTIONS.get(quant_type, 1.2) + self._test_quantization_memory_footprint( + TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize( + "quant_type", + [ + pytest.param("int4wo", marks=_int4wo_skip), + "int8wo", + "int8dq", + ], + ids=["int4wo", "int8wo", "int8dq"], + ) + def test_torchao_quantization_inference(self, quant_type): + self._test_quantization_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_quantized_layers(self, quant_type): + self._test_quantized_layers(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_quantization_lora_inference(self, quant_type): + 
self._test_quantization_lora_inference(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_quantization_serialization(self, quant_type, tmp_path): + """Override to use safe_serialization=False for TorchAO (safetensors not supported).""" + config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type] + model = self._create_quantized_model(config_kwargs) + + model.save_pretrained(str(tmp_path), safe_serialization=False) + + model_loaded = self.model_class.from_pretrained(str(tmp_path), device_map=str(torch_device)) + + inputs = self.get_dummy_inputs() + output = model_loaded(**inputs, return_dict=False)[0] + assert not torch.isnan(output).any(), "Loaded model output contains NaN" + + def test_torchao_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert( + TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude + ) + + def test_torchao_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + + def test_torchao_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + + def test_torchao_training(self): + """Test that quantized models can be used for training with adapters.""" + self._test_quantization_training(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]) + + +@is_quantization +@is_gguf +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFConfigMixin: + """ + Base mixin providing GGUF quantization config and model creation. + + Expected from config mixin: + - model_class: The model class to test + + Required properties (must be implemented by subclasses): + - gguf_filename: URL or path to the GGUF file + """ + + @property + def gguf_filename(self): + """URL or path to the GGUF file. Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `gguf_filename` property.") + + def _create_quantized_model(self, config_kwargs=None, **extra_kwargs): + if config_kwargs is None: + config_kwargs = {"compute_dtype": torch.bfloat16} + + config = GGUFQuantizationConfig(**config_kwargs) + kwargs = { + "quantization_config": config, + "torch_dtype": config_kwargs.get("compute_dtype", torch.bfloat16), + "device_map": str(torch_device), + } + kwargs.update(extra_kwargs) + return self.model_class.from_single_file(self.gguf_filename, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs=None): + from diffusers.quantizers.gguf.utils import GGUFParameter + + assert isinstance(module.weight, GGUFParameter), f"{name} weight is not GGUFParameter" + assert hasattr(module.weight, "quant_type"), f"{name} weight missing quant_type" + assert module.weight.dtype == torch.uint8, f"{name} weight dtype should be uint8" + + +@is_gguf +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing GGUF quantization on models. 
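+
+    Usage sketch (``MyModelTesterConfig`` and the checkpoint URL are illustrative
+    placeholders; subclasses must at minimum override ``gguf_filename``):
+
+        class TestMyModelGGUF(MyModelTesterConfig, GGUFTesterMixin):
+            @property
+            def gguf_filename(self):
+                return "https://huggingface.co/<repo>/blob/main/<model>-Q2_K.gguf"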
+ + Expected from config mixin: + - model_class: The model class to test + + Required properties (must be implemented by subclasses): + - gguf_filename: URL or path to the GGUF file + + Expected methods from config mixin: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: gguf + Use `pytest -m "not gguf"` to skip these tests + """ + + def test_gguf_quantization_inference(self): + self._test_quantization_inference({"compute_dtype": torch.bfloat16}) + + def test_gguf_keep_modules_in_fp32(self): + if not hasattr(self.model_class, "_keep_in_fp32_modules"): + pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules") + + _keep_in_fp32_modules = self.model_class._keep_in_fp32_modules + self.model_class._keep_in_fp32_modules = ["proj_out"] + + try: + model = self._create_quantized_model() + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and name in model._keep_in_fp32_modules: + assert module.weight.dtype == torch.float32, f"Module {name} should be FP32" + finally: + self.model_class._keep_in_fp32_modules = _keep_in_fp32_modules + + def test_gguf_quantization_dtype_assignment(self): + self._test_quantization_dtype_assignment({"compute_dtype": torch.bfloat16}) + + def test_gguf_quantization_lora_inference(self): + self._test_quantization_lora_inference({"compute_dtype": torch.bfloat16}) + + def test_gguf_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize({"compute_dtype": torch.bfloat16}) + + def test_gguf_quantized_layers(self): + self._test_quantized_layers({"compute_dtype": torch.bfloat16}) + + +@is_quantization +@is_modelopt +@require_accelerator +@require_accelerate +@require_modelopt_version_greater_or_equal("0.33.1") +class ModelOptConfigMixin: + """ + Base mixin providing NVIDIA ModelOpt quantization config and model creation. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + """ + + MODELOPT_CONFIGS = { + "fp8": {"quant_type": "FP8"}, + "int8": {"quant_type": "INT8"}, + "int4": {"quant_type": "INT4"}, + } + + MODELOPT_EXPECTED_MEMORY_REDUCTIONS = { + "fp8": 1.5, + "int8": 1.5, + "int4": 3.0, + } + + def _create_quantized_model(self, config_kwargs, **extra_kwargs): + config = NVIDIAModelOptConfig(**config_kwargs) + kwargs = getattr(self, "pretrained_model_kwargs", {}).copy() + kwargs["quantization_config"] = config + kwargs["device_map"] = str(torch_device) + kwargs.update(extra_kwargs) + return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs) + + def _verify_if_layer_quantized(self, name, module, config_kwargs): + assert mtq.utils.is_quantized(module), f"Layer {name} does not have weight_quantizer attribute (not quantized)" + + +@is_modelopt +@require_accelerator +@require_accelerate +@require_modelopt_version_greater_or_equal("0.33.1") +class ModelOptTesterMixin(ModelOptConfigMixin, QuantizationTesterMixin): + """ + Mixin class for testing NVIDIA ModelOpt quantization on models. 
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained (e.g., {"subfolder": "transformer"}) + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional class attributes: + - MODELOPT_CONFIGS: Dict of config name -> NVIDIAModelOptConfig kwargs to test + + Pytest mark: modelopt + Use `pytest -m "not modelopt"` to skip these tests + """ + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_quantization_num_parameters(self, config_name): + self._test_quantization_num_parameters(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize( + "config_name", + list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ids=list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ) + def test_modelopt_quantization_memory_footprint(self, config_name): + expected = ModelOptConfigMixin.MODELOPT_EXPECTED_MEMORY_REDUCTIONS.get(config_name, 1.2) + self._test_quantization_memory_footprint( + ModelOptConfigMixin.MODELOPT_CONFIGS[config_name], expected_memory_reduction=expected + ) + + @pytest.mark.parametrize( + "config_name", + list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ids=list(ModelOptConfigMixin.MODELOPT_CONFIGS.keys()), + ) + def test_modelopt_quantization_inference(self, config_name): + self._test_quantization_inference(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_quantization_dtype_assignment(self, config_name): + self._test_quantization_dtype_assignment(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_quantization_lora_inference(self, config_name): + self._test_quantization_lora_inference(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_quantization_serialization(self, config_name, tmp_path): + self._test_quantization_serialization(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name], tmp_path) + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_quantized_layers(self, config_name): + self._test_quantized_layers(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + def test_modelopt_modules_to_not_convert(self): + """Test that modules_to_not_convert parameter works correctly.""" + modules_to_exclude = getattr(self, "modules_to_not_convert_for_test", None) + if modules_to_exclude is None: + pytest.skip("modules_to_not_convert_for_test not defined for this model") + + self._test_quantization_modules_to_not_convert(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"], modules_to_exclude) + + def test_modelopt_device_map(self): + """Test that device_map='auto' works correctly with quantization.""" + self._test_quantization_device_map(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"]) + + def test_modelopt_dequantize(self): + """Test that dequantize() works correctly.""" + self._test_dequantize(ModelOptConfigMixin.MODELOPT_CONFIGS["fp8"]) + + +@is_quantization +@is_torch_compile +class QuantizationCompileTesterMixin: + """ + Base mixin class providing common test implementations for torch.compile with quantized models. + + Backend-specific compile mixins should: + 1. 
Inherit from their respective config mixin (e.g., BitsAndBytesConfigMixin) + 2. Inherit from this mixin + 3. Define the config to use for compile tests + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods in test classes: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + @torch.no_grad() + def _test_torch_compile(self, config_kwargs): + """ + Test that torch.compile works correctly with a quantized model. + + Args: + config_kwargs: Quantization config parameters + """ + model = self._create_quantized_model(config_kwargs) + model.to(torch_device) + model.eval() + + model = torch.compile(model, fullgraph=True) + + with torch._dynamo.config.patch(error_on_recompile=True): + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + @torch.no_grad() + def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False): + """ + Test that torch.compile works correctly with a quantized model and group offloading. + + Args: + config_kwargs: Quantization config parameters + use_stream: Whether to use CUDA streams for offloading + """ + torch._dynamo.config.cache_size_limit = 1000 + + model = self._create_quantized_model(config_kwargs) + model.eval() + + if not hasattr(model, "enable_group_offload"): + pytest.skip("Model does not support group offloading") + + group_offload_kwargs = { + "onload_device": torch.device(torch_device), + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": use_stream, + } + model.enable_group_offload(**group_offload_kwargs) + model = torch.compile(model) + + inputs = self.get_dummy_inputs() + output = model(**inputs, return_dict=False)[0] + assert output is not None, "Model output is None" + assert not torch.isnan(output).any(), "Model output contains NaN" + + +@is_bitsandbytes +@require_accelerator +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +class BitsAndBytesCompileTesterMixin(BitsAndBytesConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with BitsAndBytes quantized models. 
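+
+    Usage sketch (``MyModelTesterConfig`` is an illustrative placeholder); as done
+    for the Flux transformer in this PR, test classes may mark this combination as
+    skipped where torch.compile is not supported by BitsAndBytes:
+
+        @pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes")
+        class TestMyModelBnBCompile(MyModelTesterConfig, BitsAndBytesCompileTesterMixin):
+            pass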
+ + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: bitsandbytes + Use `pytest -m "not bitsandbytes"` to skip these tests + """ + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_torch_compile(self, config_name): + self._test_torch_compile(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["4bit_nf4"], ids=["4bit_nf4"]) + def test_bnb_torch_compile_with_group_offload(self, config_name): + self._test_torch_compile_with_group_offload(BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]) + + +@is_quanto +@require_quanto +@require_accelerate +@require_accelerator +class QuantoCompileTesterMixin(QuantoConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with Quanto quantized models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: quanto + Use `pytest -m "not quanto"` to skip these tests + """ + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_torch_compile(self, weight_type_name): + self._test_torch_compile(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + @pytest.mark.parametrize("weight_type_name", ["int8"], ids=["int8"]) + def test_quanto_torch_compile_with_group_offload(self, weight_type_name): + self._test_torch_compile_with_group_offload(QuantoConfigMixin.QUANTO_WEIGHT_TYPES[weight_type_name]) + + +@is_torchao +@require_accelerator +@require_torchao_version_greater_or_equal("0.7.0") +class TorchAoCompileTesterMixin(TorchAoConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with TorchAO quantized models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: torchao + Use `pytest -m "not torchao"` to skip these tests + """ + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_torch_compile(self, quant_type): + self._test_torch_compile(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"]) + def test_torchao_torch_compile_with_group_offload(self, quant_type): + self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type]) + + +@is_gguf +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +class GGUFCompileTesterMixin(GGUFConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with GGUF quantized models. 
+ + Expected from config mixin: + - model_class: The model class to test + + Required properties (must be implemented by subclasses): + - gguf_filename: URL or path to the GGUF file + + Expected methods from config mixin: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: gguf + Use `pytest -m "not gguf"` to skip these tests + """ + + def test_gguf_torch_compile(self): + self._test_torch_compile({"compute_dtype": torch.bfloat16}) + + def test_gguf_torch_compile_with_group_offload(self): + self._test_torch_compile_with_group_offload({"compute_dtype": torch.bfloat16}) + + +@is_modelopt +@require_accelerator +@require_accelerate +@require_modelopt_version_greater_or_equal("0.33.1") +class ModelOptCompileTesterMixin(ModelOptConfigMixin, QuantizationCompileTesterMixin): + """ + Mixin class for testing torch.compile with NVIDIA ModelOpt quantized models. + + Expected class attributes: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: (Optional) Dict of kwargs to pass to from_pretrained + + Expected methods to be implemented by subclasses: + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: modelopt + Use `pytest -m "not modelopt"` to skip these tests + """ + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_torch_compile(self, config_name): + self._test_torch_compile(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) + + @pytest.mark.parametrize("config_name", ["fp8"], ids=["fp8"]) + def test_modelopt_torch_compile_with_group_offload(self, config_name): + self._test_torch_compile_with_group_offload(ModelOptConfigMixin.MODELOPT_CONFIGS[config_name]) diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py new file mode 100644 index 000000000000..e2b9dadb6140 --- /dev/null +++ b/tests/models/testing_utils/single_file.py @@ -0,0 +1,272 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc + +import torch +from huggingface_hub import hf_hub_download, snapshot_download + +from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name + +from ...testing_utils import ( + backend_empty_cache, + is_single_file, + nightly, + require_torch_accelerator, + torch_device, +) +from .common import check_device_map_is_respected + + +def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir): + """Download a single file checkpoint from the Hub to a temporary directory.""" + path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir) + return path + + +def download_diffusers_config(pretrained_model_name_or_path, tmpdir): + """Download diffusers config files (excluding weights) from a repository.""" + path = snapshot_download( + pretrained_model_name_or_path, + ignore_patterns=[ + "**/*.ckpt", + "*.ckpt", + "**/*.bin", + "*.bin", + "**/*.pt", + "*.pt", + "**/*.safetensors", + "*.safetensors", + ], + allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"], + local_dir=tmpdir, + ) + return path + + +@nightly +@require_torch_accelerator +@is_single_file +class SingleFileTesterMixin: + """ + Mixin class for testing single file loading for models. + + Required properties (must be implemented by subclasses): + - ckpt_path: Path or Hub path to the single file checkpoint + + Optional properties: + - torch_dtype: torch dtype to use for testing (default: None) + - alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None) + + Expected from config mixin: + - model_class: The model class to test + - pretrained_model_name_or_path: Hub repository ID for the pretrained model + - pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder) + + Pytest mark: single_file + Use `pytest -m "not single_file"` to skip these tests + """ + + # ==================== Required Properties ==================== + + @property + def ckpt_path(self) -> str: + """Path or Hub path to the single file checkpoint. 
Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `ckpt_path` property.") + + # ==================== Optional Properties ==================== + + @property + def torch_dtype(self) -> torch.dtype | None: + """torch dtype to use for single file testing.""" + return None + + @property + def alternate_ckpt_paths(self) -> list[str] | None: + """List of alternate checkpoint paths for variant testing.""" + return None + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_model_config(self): + pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs} + single_file_kwargs = {"device": torch_device} + + if self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert model.config[param_name] == param_value, ( + f"{param_name} differs between pretrained loading and single file loading: " + f"pretrained={model.config[param_name]}, single_file={param_value}" + ) + + def test_single_file_model_parameters(self): + pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs} + single_file_kwargs = {"device": torch_device} + + if self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs["torch_dtype"] = self.torch_dtype + + # Load pretrained model, get state dict on CPU, then free GPU memory + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + state_dict = {k: v.cpu() for k, v in model.state_dict().items()} + del model + gc.collect() + backend_empty_cache(torch_device) + + # Load single file model, get state dict on CPU + model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + state_dict_single_file = {k: v.cpu() for k, v in model_single_file.state_dict().items()} + del model_single_file + gc.collect() + backend_empty_cache(torch_device) + + assert set(state_dict.keys()) == set(state_dict_single_file.keys()), ( + "Model parameters keys differ between pretrained and single file loading. " + f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. 
" + f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}" + ) + + for key in state_dict.keys(): + param = state_dict[key] + param_single_file = state_dict_single_file[key] + + assert param.shape == param_single_file.shape, ( + f"Parameter shape mismatch for {key}: " + f"pretrained {param.shape} vs single file {param_single_file.shape}" + ) + + assert torch.equal(param, param_single_file), f"Parameter values differ for {key}" + + def test_single_file_loading_local_files_only(self, tmp_path): + single_file_kwargs = {} + + if self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path)) + + model_single_file = self.model_class.from_single_file( + local_ckpt_path, local_files_only=True, **single_file_kwargs + ) + + assert model_single_file is not None, "Failed to load model with local_files_only=True" + + def test_single_file_loading_with_diffusers_config(self): + single_file_kwargs = {} + + if self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs.update(self.pretrained_model_kwargs) + + # Load with config parameter + model_single_file = self.model_class.from_single_file( + self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs + ) + + # Load pretrained for comparison + pretrained_kwargs = {**self.pretrained_model_kwargs} + if self.torch_dtype: + pretrained_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) + + # Compare configs + PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"] + for param_name, param_value in model_single_file.config.items(): + if param_name in PARAMS_TO_IGNORE: + continue + assert model.config[param_name] == param_value, ( + f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}" + ) + + def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path): + single_file_kwargs = {} + + if self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + single_file_kwargs.update(self.pretrained_model_kwargs) + + pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) + local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path)) + local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path)) + + model_single_file = self.model_class.from_single_file( + local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs + ) + + assert model_single_file is not None, "Failed to load model with config and local_files_only=True" + + def test_single_file_loading_dtype(self): + for dtype in [torch.float32, torch.float16]: + if torch_device == "mps" and dtype == torch.bfloat16: + continue + + model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype) + + assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}" + + # Cleanup + del model_single_file + gc.collect() + backend_empty_cache(torch_device) + + def test_checkpoint_variant_loading(self): + if not self.alternate_ckpt_paths: + return + + for ckpt_path in self.alternate_ckpt_paths: + 
backend_empty_cache(torch_device) + + single_file_kwargs = {} + if self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs) + + assert model is not None, f"Failed to load checkpoint from {ckpt_path}" + + del model + gc.collect() + backend_empty_cache(torch_device) + + def test_single_file_loading_with_device_map(self): + single_file_kwargs = {"device_map": torch_device} + + if self.torch_dtype: + single_file_kwargs["torch_dtype"] = self.torch_dtype + + model = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs) + + assert model is not None, "Failed to load model with device_map" + assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute when loaded with device_map" + assert model.hf_device_map is not None, "hf_device_map should not be None when loaded with device_map" + check_device_map_is_respected(model, model.hf_device_map) diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py new file mode 100644 index 000000000000..44cce6af68e5 --- /dev/null +++ b/tests/models/testing_utils/training.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import gc + +import pytest +import torch + +from diffusers.training_utils import EMAModel + +from ...testing_utils import ( + backend_empty_cache, + is_training, + require_torch_accelerator_with_training, + torch_all_close, + torch_device, +) + + +@is_training +@require_torch_accelerator_with_training +class TrainingTesterMixin: + """ + Mixin class for testing training functionality on models. 
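+
+    Usage sketch (``MyModelTesterConfig`` is an illustrative placeholder supplying
+    ``get_init_dict()``, ``get_dummy_inputs()`` and ``output_shape``):
+
+        class TestMyModelTraining(MyModelTesterConfig, TrainingTesterMixin):
+            pass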
+ + Expected from config mixin: + - model_class: The model class to test + - output_shape: Tuple defining the expected output shape + + Expected methods from config mixin: + - get_init_dict(): Returns dict of arguments to initialize the model + - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Pytest mark: training + Use `pytest -m "not training"` to skip these tests + """ + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_training(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + output = model(**inputs_dict, return_dict=False)[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + + def test_training_with_ema(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + ema_model = EMAModel(model.parameters()) + + output = model(**inputs_dict, return_dict=False)[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + loss.backward() + ema_model.step(model.parameters()) + + def test_gradient_checkpointing(self): + if not self.model_class._supports_gradient_checkpointing: + pytest.skip("Gradient checkpointing is not supported.") + + init_dict = self.get_init_dict() + + # at init model should have gradient checkpointing disabled + model = self.model_class(**init_dict) + assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init" + + # check enable works + model.enable_gradient_checkpointing() + assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled" + + # check disable works + model.disable_gradient_checkpointing() + assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled" + + def test_gradient_checkpointing_is_applied(self, expected_set=None): + if not self.model_class._supports_gradient_checkpointing: + pytest.skip("Gradient checkpointing is not supported.") + + if expected_set is None: + pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.") + + init_dict = self.get_init_dict() + + model_class_copy = copy.copy(self.model_class) + model = model_class_copy(**init_dict) + model.enable_gradient_checkpointing() + + modules_with_gc_enabled = {} + for submodule in model.modules(): + if hasattr(submodule, "gradient_checkpointing"): + assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled" + modules_with_gc_enabled[submodule.__class__.__name__] = True + + assert set(modules_with_gc_enabled.keys()) == expected_set, ( + f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}" + ) + assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled" + + def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None): + if not self.model_class._supports_gradient_checkpointing: + pytest.skip("Gradient checkpointing is not supported.") + + if skip is None: + skip = set() + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + inputs_dict_copy = 
copy.deepcopy(inputs_dict) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + model.to(torch_device) + + assert not model.is_gradient_checkpointing and model.training + + out = model(**inputs_dict, return_dict=False)[0] + + # run the backwards pass on the model + model.zero_grad() + + labels = torch.randn_like(out) + loss = (out - labels).mean() + loss.backward() + + # re-instantiate the model now enabling gradient checkpointing + torch.manual_seed(0) + model_2 = self.model_class(**init_dict) + # clone model + model_2.load_state_dict(model.state_dict()) + model_2.to(torch_device) + model_2.enable_gradient_checkpointing() + + assert model_2.is_gradient_checkpointing and model_2.training + + out_2 = model_2(**inputs_dict_copy, return_dict=False)[0] + + # run the backwards pass on the model + model_2.zero_grad() + loss_2 = (out_2 - labels).mean() + loss_2.backward() + + # compare the output and parameters gradients + assert (loss - loss_2).abs() < loss_tolerance, ( + f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}" + ) + + named_params = dict(model.named_parameters()) + named_params_2 = dict(model_2.named_parameters()) + + for name, param in named_params.items(): + if "post_quant_conv" in name: + continue + if name in skip: + continue + if param.grad is None: + continue + + assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), ( + f"Gradient mismatch for {name}" + ) + + def test_mixed_precision_training(self): + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.train() + + # Test with float16 + if torch.device(torch_device).type != "cpu": + with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16): + output = model(**inputs_dict, return_dict=False)[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + + loss.backward() + + # Test with bfloat16 + if torch.device(torch_device).type != "cpu": + model.zero_grad() + with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): + output = model(**inputs_dict, return_dict=False)[0] + + noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device) + loss = torch.nn.functional.mse_loss(output, noise) + + loss.backward() diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 3ab02f797b5b..2d39dadfcad1 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -13,23 +13,52 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest +from typing import Any +import pytest import torch from diffusers import FluxTransformer2DModel -from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0 from diffusers.models.embeddings import ImageProjection - -from ...testing_utils import enable_full_determinism, is_peft_available, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesCompileTesterMixin, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + GGUFCompileTesterMixin, + GGUFTesterMixin, + IPAdapterTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelOptCompileTesterMixin, + ModelOptTesterMixin, + ModelTesterMixin, + PyramidAttentionBroadcastTesterMixin, + QuantoCompileTesterMixin, + QuantoTesterMixin, + SingleFileTesterMixin, + TorchAoCompileTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -def create_flux_ip_adapter_state_dict(model): - # "ip_adapter" (cross-attention weights) +# TODO: This standalone function maintains backward compatibility with pipeline tests +# (tests/pipelines/test_pipelines_common.py) and will be refactored. +def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]: + """Create a dummy IP Adapter state dict for Flux transformer testing.""" ip_cross_attn_state_dict = {} key_id = 0 @@ -39,7 +68,7 @@ def create_flux_ip_adapter_state_dict(model): joint_attention_dim = model.config["joint_attention_dim"] hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"] - sd = FluxIPAdapterJointAttnProcessor2_0( + sd = FluxIPAdapterAttnProcessor( hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0 ).state_dict() ip_cross_attn_state_dict.update( @@ -50,11 +79,8 @@ def create_flux_ip_adapter_state_dict(model): f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"], } ) - key_id += 1 - # "image_proj" (ImageProjection layer weights) - image_projection = ImageProjection( cross_attention_dim=model.config["joint_attention_dim"], image_embed_dim=( @@ -75,57 +101,45 @@ def create_flux_ip_adapter_state_dict(model): ) del sd - ip_state_dict = {} - ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}) - return ip_state_dict + return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict} -class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = FluxTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. 
- model_split_percents = [0.7, 0.6, 0.6] +class FluxTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return FluxTransformer2DModel - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-flux-pipe" @property - def dummy_input(self): - return self.prepare_dummy_input() + def pretrained_model_kwargs(self): + return {"subfolder": "transformer"} @property - def input_shape(self): + def output_shape(self) -> tuple[int, int]: return (16, 4) @property - def output_shape(self): + def input_shape(self) -> tuple[int, int]: return (16, 4) - def prepare_dummy_input(self, height=4, width=4): - batch_size = 1 - num_latent_channels = 4 - num_image_channels = 3 - sequence_length = 48 - embedding_dim = 32 + @property + def model_split_percents(self) -> list: + return [0.9] - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device) - text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device) - image_ids = torch.randn((height * width, num_image_channels)).to(torch_device) - timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + @property + def main_input_name(self) -> str: + return "hidden_states" - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "img_ids": image_ids, - "txt_ids": text_ids, - "pooled_projections": pooled_prompt_embeds, - "timestep": timestep, - } + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict[str, int | list[int]]: + """Return Flux model initialization arguments.""" + return { "patch_size": 1, "in_channels": 4, "num_layers": 1, @@ -137,11 +151,40 @@ def prepare_init_args_and_inputs_for_common(self): "axes_dims_rope": [4, 4, 8], } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 1 + height = width = 4 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 48 + embedding_dim = 32 + return { + "hidden_states": randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ), + "pooled_projections": randn_tensor( + (batch_size, embedding_dim), generator=self.generator, device=torch_device + ), + "img_ids": randn_tensor( + (height * width, num_image_channels), generator=self.generator, device=torch_device + ), + "txt_ids": randn_tensor( + (sequence_length, num_image_channels), generator=self.generator, device=torch_device + ), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + +class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): def test_deprecated_inputs_img_txt_ids_3d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + """Test that deprecated 3D img_ids and txt_ids still work.""" + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) model.to(torch_device) model.eval() @@ 
-162,63 +205,228 @@ def test_deprecated_inputs_img_txt_ids_3d(self): with torch.no_grad(): output_2 = model(**inputs_dict).to_tuple()[0] - self.assertEqual(output_1.shape, output_2.shape) - self.assertTrue( - torch.allclose(output_1, output_2, atol=1e-5), - msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs", + assert output_1.shape == output_2.shape + assert torch.allclose(output_1, output_2, atol=1e-5), ( + "output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) " + "are not equal as them as 2d inputs" ) - def test_gradient_checkpointing_is_applied(self): - expected_set = {"FluxTransformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - # The test exists for cases like - # https://github.com/huggingface/diffusers/issues/11874 - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_exclude_modules(self): - from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict - - lora_rank = 4 - target_module = "single_transformer_blocks.0.proj_out" - adapter_name = "foo" - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - state_dict = model.state_dict() - target_mod_shape = state_dict[f"{target_module}.weight"].shape - lora_state_dict = { - f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22, - f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33, + +class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for Flux Transformer.""" + + +class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin): + """Training tests for Flux Transformer.""" + + +class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for Flux Transformer.""" + + +class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin): + """Context Parallel inference tests for Flux Transformer""" + + +class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): + """IP Adapter tests for Flux Transformer.""" + + @property + def ip_adapter_processor_cls(self): + return FluxIPAdapterAttnProcessor + + def modify_inputs_for_ip_adapter(self, model, inputs_dict): + torch.manual_seed(0) + # Create dummy image embeds for IP adapter + cross_attention_dim = getattr(model.config, "joint_attention_dim", 32) + image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device) + + inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}}) + + return inputs_dict + + def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]: + return create_flux_ip_adapter_state_dict(model) + + +class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for Flux Transformer.""" + + +class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for Flux Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for LoRA hotswap tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + 
sequence_length = 24 + embedding_dim = 32 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device), + "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device), + "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device), + "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), + } + + +class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin): + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + """Override to support dynamic height/width for compilation tests.""" + batch_size = 1 + num_latent_channels = 4 + num_image_channels = 3 + sequence_length = 24 + embedding_dim = 32 + + return { + "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device), + "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device), + "pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device), + "img_ids": randn_tensor((height * width, num_image_channels), device=torch_device), + "txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device), + "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size), } - # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter). - config = LoraConfig( - r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"] - ) - inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict) - set_peft_model_state_dict(model, lora_state_dict, adapter_name) - retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) - assert len(retrieved_lora_state_dict) == len(lora_state_dict) - assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all() - assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all() -class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = FluxTransformer2DModel - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] +class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin): + @property + def ckpt_path(self): + return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + + @property + def alternate_ckpt_paths(self): + return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"] + + @property + def pretrained_model_name_or_path(self): + return "black-forest-labs/FLUX.1-dev" + + +class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for Flux Transformer.""" + + +class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin): + """Quanto quantization tests for Flux Transformer.""" + + @property + def pretrained_model_name_or_path(self): + return "hf-internal-testing/tiny-flux-transformer" + + @property + def pretrained_model_kwargs(self): + return {} + + +class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, 
TorchAoTesterMixin): + """TorchAO quantization tests for Flux Transformer.""" + + +class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin): + @property + def gguf_filename(self): + return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" + + @property + def torch_dtype(self): + return torch.bfloat16 + + def get_dummy_inputs(self): + """Override to provide inputs matching the real FLUX model dimensions.""" + return { + "hidden_states": randn_tensor( + (1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ), + "encoder_hidden_states": randn_tensor( + (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ), + "pooled_projections": randn_tensor( + (1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + "img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype), + "txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype), + "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), + } + + +class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin): + """Quanto + compile tests for Flux Transformer.""" + + +class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin): + """TorchAO + compile tests for Flux Transformer.""" + + +class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin): + @property + def gguf_filename(self): + return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" + + @property + def torch_dtype(self): + return torch.bfloat16 + + def get_dummy_inputs(self): + """Override to provide inputs matching the real FLUX model dimensions.""" + return { + "hidden_states": randn_tensor( + (1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ), + "encoder_hidden_states": randn_tensor( + (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ), + "pooled_projections": randn_tensor( + (1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype + ), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + "img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype), + "txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype), + "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype), + } + + +class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin): + """ModelOpt quantization tests for Flux Transformer.""" + + +class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin): + """ModelOpt + compile tests for Flux Transformer.""" + + +@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes") +class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin): + """BitsAndBytes + compile tests for Flux Transformer.""" + - def prepare_init_args_and_inputs_for_common(self): - return FluxTransformerTests().prepare_init_args_and_inputs_for_common() +class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin): + """PyramidAttentionBroadcast cache tests for Flux Transformer.""" - def prepare_dummy_input(self, 
height, width): - return FluxTransformerTests().prepare_dummy_input(height=height, width=width) +class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin): + """FirstBlockCache tests for Flux Transformer.""" -class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase): - model_class = FluxTransformer2DModel - different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)] - def prepare_init_args_and_inputs_for_common(self): - return FluxTransformerTests().prepare_init_args_and_inputs_for_common() +class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin): + """FasterCache tests for Flux Transformer.""" - def prepare_dummy_input(self, height, width): - return FluxTransformerTests().prepare_dummy_input(height=height, width=width) + # Flux is guidance distilled, so we can test at model level without CFG batch handling + FASTER_CACHE_CONFIG = { + "spatial_attention_block_skip_range": 2, + "spatial_attention_timestep_skip_range": (-1, 901), + "tensor_format": "BCHW", + "is_guidance_distilled": True, + } diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py new file mode 100644 index 000000000000..af9ef0623891 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
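+#
+# Fast tests for LTX2VideoTransformer3DModel. The dummy inputs below exercise the joint
+# video + audio path: patchified video latents ("hidden_states"), audio latents built from
+# mel bins ("audio_hidden_states"), separate text-conditioning streams for each modality,
+# and the `num_frames` / `height` / `width` / `audio_num_frames` / `fps` metadata passed
+# alongside the tensors.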
+ +import unittest + +import torch + +from diffusers import LTX2VideoTransformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + # Common + batch_size = 2 + + # Video + num_frames = 2 + num_channels = 4 + height = 16 + width = 16 + + # Audio + audio_num_frames = 9 + audio_num_channels = 2 + num_mel_bins = 2 + + # Text + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) + audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to( + torch_device + ) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.rand((batch_size,)).to(torch_device) * 1000 + + return { + "hidden_states": hidden_states, + "audio_hidden_states": audio_hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "audio_encoder_hidden_states": audio_encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + "num_frames": num_frames, + "height": height, + "width": width, + "audio_num_frames": audio_num_frames, + "fps": 25.0, + } + + @property + def input_shape(self): + return (512, 4) + + @property + def output_shape(self): + return (512, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "patch_size": 1, + "patch_size_t": 1, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "audio_in_channels": 4, + "audio_out_channels": 4, + "audio_num_attention_heads": 2, + "audio_attention_head_dim": 4, + "audio_cross_attention_dim": 8, + "num_layers": 2, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + "rope_double_precision": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LTX2VideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + # def test_ltx2_consistency(self, seed=0, dtype=torch.float32): + # torch.manual_seed(seed) + # init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # # Calculate dummy inputs in a custom manner to ensure compatibility with original code + # batch_size = 2 + # num_frames = 9 + # latent_frames = 2 + # text_embedding_dim = 16 + # text_seq_len = 16 + # fps = 25.0 + # sampling_rate = 16000.0 + # hop_length = 160.0 + + # sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000 + # timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) + + # num_channels = 4 + # latent_height = 4 + # latent_width = 4 + # hidden_states = torch.randn( + # (batch_size, num_channels, latent_frames, latent_height, latent_width), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + # # Patchify video latents (with patch_size (1, 1, 1)) + # 
hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) + # hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + # encoder_hidden_states = torch.randn( + # (batch_size, text_seq_len, text_embedding_dim), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + + # audio_num_channels = 2 + # num_mel_bins = 2 + # latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) + # audio_hidden_states = torch.randn( + # (batch_size, audio_num_channels, latent_length, num_mel_bins), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + # # Patchify audio latents + # audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) + # audio_encoder_hidden_states = torch.randn( + # (batch_size, text_seq_len, text_embedding_dim), + # generator=torch.manual_seed(seed), + # dtype=dtype, + # device="cpu", + # ) + + # inputs_dict = { + # "hidden_states": hidden_states.to(device=torch_device), + # "audio_hidden_states": audio_hidden_states.to(device=torch_device), + # "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), + # "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), + # "timestep": timestep, + # "num_frames": latent_frames, + # "height": latent_height, + # "width": latent_width, + # "audio_num_frames": num_frames, + # "fps": 25.0, + # } + + # model = self.model_class.from_pretrained( + # "diffusers-internal-dev/dummy-ltx2", + # subfolder="transformer", + # device_map="cpu", + # ) + # # torch.manual_seed(seed) + # # model = self.model_class(**init_dict) + # model.to(torch_device) + # model.eval() + + # with attention_backend("native"): + # with torch.no_grad(): + # output = model(**inputs_dict) + + # video_output, audio_output = output.to_tuple() + + # self.assertIsNotNone(video_output) + # self.assertIsNotNone(audio_output) + + # # input & output have to have the same shape + # video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) + # self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") + # audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) + # self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") + + # # Check against expected slice + # # fmt: off + # video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) + # audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) + # # fmt: on + + # video_output_flat = video_output.cpu().flatten().float() + # video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) + # self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) + + # audio_output_flat = audio_output.cpu().flatten().float() + # audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) + # self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) + + +class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): + model_class = LTX2VideoTransformer3DModel + + def prepare_init_args_and_inputs_for_common(self): + return 
LTX2TransformerTests().prepare_init_args_and_inputs_for_common() diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503ef..e6b19377b14f 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -15,10 +15,10 @@ import unittest -import pytest import torch from diffusers import QwenImageTransformer2DModel +from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -68,7 +68,6 @@ def prepare_dummy_input(self, height=4, width=4): "encoder_hidden_states_mask": encoder_hidden_states_mask, "timestep": timestep, "img_shapes": img_shapes, - "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(), } def prepare_init_args_and_inputs_for_common(self): @@ -91,6 +90,180 @@ def test_gradient_checkpointing_is_applied(self): expected_set = {"QwenImageTransformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_infers_text_seq_len_from_mask(self): + """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + + rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + + # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) + self.assertIsInstance(rope_text_seq_len, int) + + # Verify per_sample_len is computed correctly (max valid position + 1 = 2) + self.assertIsInstance(per_sample_len, torch.Tensor) + self.assertEqual(int(per_sample_len.max().item()), 2) + + # Verify mask is normalized to bool dtype + self.assertTrue(normalized_mask.dtype == torch.bool) + self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values + + # Verify rope_text_seq_len is at least the sequence length + self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + + # Test 2: Verify model runs successfully with inferred values + inputs["encoder_hidden_states_mask"] = normalized_mask + with torch.no_grad(): + output = model(**inputs) + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Different mask pattern (padding at beginning) + encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() + encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding + encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + + rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask2 + ) + + # Max valid position is 6 (last token), so per_sample_len should be 7 + self.assertEqual(int(per_sample_len2.max().item()), 7) + self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + + # Test 4: No mask provided (None case) + rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = 
compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], None + ) + self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(rope_text_seq_len_none, int) + self.assertIsNone(per_sample_len_none) + self.assertIsNone(normalized_mask_none) + + def test_non_contiguous_attention_mask(self): + """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. + encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() + # Pattern: [True, False, True, False, True, False, False] + encoder_hidden_states_mask[:, 1] = 0 + encoder_hidden_states_mask[:, 3] = 0 + encoder_hidden_states_mask[:, 5:] = 0 + + inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( + inputs["encoder_hidden_states"], encoder_hidden_states_mask + ) + self.assertEqual(int(per_sample_len.max().item()), 5) + self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) + self.assertIsInstance(inferred_rope_len, int) + self.assertTrue(normalized_mask.dtype == torch.bool) + + inputs["encoder_hidden_states_mask"] = normalized_mask + + with torch.no_grad(): + output = model(**inputs) + + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_txt_seq_lens_deprecation(self): + """Test that passing txt_seq_lens raises a deprecation warning.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + + # Prepare inputs with txt_seq_lens (deprecated parameter) + txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] + + # Remove encoder_hidden_states_mask to use the deprecated path + inputs_with_deprecated = inputs.copy() + inputs_with_deprecated.pop("encoder_hidden_states_mask") + inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens + + # Test that deprecation warning is raised + with self.assertWarns(FutureWarning) as warning_context: + with torch.no_grad(): + output = model(**inputs_with_deprecated) + + # Verify the warning message mentions the deprecation + warning_message = str(warning_context.warning) + self.assertIn("txt_seq_lens", warning_message) + self.assertIn("deprecated", warning_message) + self.assertIn("encoder_hidden_states_mask", warning_message) + + # Verify the model still works correctly despite the deprecation + self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + + def test_layered_model_with_mask(self): + """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" + # Create layered model config + init_dict = { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 3, + "joint_attention_dim": 16, + "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) + "use_layer3d_rope": True, # Enable layered RoPE + "use_additional_t_cond": True, # Enable additional time conditioning + } + + model = self.model_class(**init_dict).to(torch_device) + + # Verify the model uses QwenEmbedLayer3DRope + from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope + + self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + + # Test single generation with layered structure + batch_size = 1 + text_seq_len = 7 + img_h, img_w = 4, 
4 + layers = 4 + + # For layered model: (layers + 1) because we have N layers + 1 combined image + hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) + encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) + + # Create mask with some padding + encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) + encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + + timestep = torch.tensor([1.0]).to(torch_device) + + # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) + addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) + + # Layer structure: 4 layers + 1 condition image + img_shapes = [ + [ + (1, img_h, img_w), # layer 0 + (1, img_h, img_w), # layer 1 + (1, img_h, img_w), # layer 2 + (1, img_h, img_w), # layer 3 + (1, img_h, img_w), # condition image (last one gets special treatment) + ] + ] + + with torch.no_grad(): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + timestep=timestep, + img_shapes=img_shapes, + additional_t_cond=addition_t_cond, + ) + + self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel @@ -101,6 +274,76 @@ def prepare_init_args_and_inputs_for_common(self): def prepare_dummy_input(self, height, width): return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) - @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True) def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() + + def test_torch_compile_with_and_without_mask(self): + """Test that torch.compile works with both None mask and padding mask.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + model.compile(mode="default", fullgraph=True) + + # Test 1: Run with None mask (no padding, all tokens are valid) + inputs_no_mask = inputs.copy() + inputs_no_mask["encoder_hidden_states_mask"] = None + + # First run to allow compilation + with torch.no_grad(): + output_no_mask = model(**inputs_no_mask) + + # Second run to verify no recompilation + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + output_no_mask_2 = model(**inputs_no_mask) + + self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) + self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 2: Run with all-ones mask (should behave like None) + inputs_all_ones = inputs.copy() + # Keep the all-ones mask + self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + + # First run to allow compilation + with torch.no_grad(): + output_all_ones = model(**inputs_all_ones) + + # Second run to verify no recompilation + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + output_all_ones_2 = model(**inputs_all_ones) + + self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) + self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Run with actual padding mask (has 
zeros) + inputs_with_padding = inputs.copy() + mask_with_padding = inputs["encoder_hidden_states_mask"].clone() + mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding + + inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding + + # First run to allow compilation + with torch.no_grad(): + output_with_padding = model(**inputs_with_padding) + + # Second run to verify no recompilation + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + output_with_padding_2 = model(**inputs_with_padding) + + self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) + self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Verify that outputs are different (mask should affect results) + self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) diff --git a/tests/modular_pipelines/flux2/__init__.py b/tests/modular_pipelines/flux2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py new file mode 100644 index 000000000000..8fd529e97e71 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py @@ -0,0 +1,93 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2AutoBlocks, + Flux2ModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2ModularPipeline + pipeline_blocks_class = Flux2AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2ModularPipeline + pipeline_blocks_class = Flux2AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular" + + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py new file mode 100644 index 000000000000..26653b20f8c4 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2KleinAutoBlocks, + Flux2KleinModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular" + + params = frozenset(["prompt", "height", "width", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return diff --git a/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py new file mode 100644 index 000000000000..701dd0fed896 --- /dev/null +++ b/tests/modular_pipelines/flux2/test_modular_pipeline_flux2_klein_base.py @@ -0,0 +1,91 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random + +import numpy as np +import PIL +import pytest + +from diffusers.modular_pipelines import ( + Flux2KleinBaseAutoBlocks, + Flux2KleinModularPipeline, +) + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinBaseAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = Flux2KleinModularPipeline + pipeline_blocks_class = Flux2KleinBaseAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular" + + params = frozenset(["prompt", "height", "width", "image"]) + batch_params = frozenset(["prompt", "image"]) + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer + "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch + "text_encoder_out_layers": (1,), + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "output_type": "pt", + } + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB") + inputs["image"] = init_image + + return inputs + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + @pytest.mark.skip(reason="batched inference is currently not supported") + def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001): + return diff --git a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py index 8d7600781b24..f4bd27b7ea47 100644 --- a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py +++ b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py @@ -26,6 +26,7 @@ QwenImageModularPipeline, ) +from ...testing_utils import torch_device from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin @@ -104,6 +105,16 @@ def get_dummy_inputs(self): inputs["image"] = PIL.Image.new("RGB", (32, 32), 0) return inputs + def test_multi_images_as_input(self): + inputs = self.get_dummy_inputs() + image = inputs.pop("image") + inputs["image"] = [image, image] + + pipe = self.get_pipeline().to(torch_device) + _ = pipe( + **inputs, + ) + @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True) def test_num_images_per_prompt(self): super().test_num_images_per_prompt() 
@@ -117,4 +128,4 @@ def test_inference_batch_single_identical(): super().test_inference_batch_single_identical() def test_guider_cfg(self): - super().test_guider_cfg(1e-3) + super().test_guider_cfg(1e-6) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index a33951dac538..a08ca2fb759c 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -8,6 +8,13 @@ import diffusers from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks from diffusers.guiders import ClassifierFreeGuidance +from diffusers.modular_pipelines.modular_pipeline_utils import ( + ComponentSpec, + ConfigSpec, + InputParam, + OutputParam, + generate_modular_model_card_content, +) from diffusers.utils import logging from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device @@ -165,7 +172,6 @@ def test_inference_batch_single_identical( expected_max_diff=1e-4, ): pipe = self.get_pipeline().to(torch_device) - inputs = self.get_dummy_inputs() # Reset generator in case it is has been used in self.get_dummy_inputs @@ -336,3 +342,239 @@ def test_guider_cfg(self, expected_max_diff=1e-2): assert out_cfg.shape == out_no_cfg.shape max_diff = torch.abs(out_cfg - out_no_cfg).max() assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference" + + +class TestModularModelCardContent: + def create_mock_block(self, name="TestBlock", description="Test block description"): + class MockBlock: + def __init__(self, name, description): + self.__class__.__name__ = name + self.description = description + self.sub_blocks = {} + + return MockBlock(name, description) + + def create_mock_blocks( + self, + class_name="TestBlocks", + description="Test pipeline description", + num_blocks=2, + components=None, + configs=None, + inputs=None, + outputs=None, + trigger_inputs=None, + model_name=None, + ): + class MockBlocks: + def __init__(self): + self.__class__.__name__ = class_name + self.description = description + self.sub_blocks = {} + self.expected_components = components or [] + self.expected_configs = configs or [] + self.inputs = inputs or [] + self.outputs = outputs or [] + self.trigger_inputs = trigger_inputs + self.model_name = model_name + + blocks = MockBlocks() + + # Add mock sub-blocks + for i in range(num_blocks): + block_name = f"block_{i}" + blocks.sub_blocks[block_name] = self.create_mock_block(f"Block{i}", f"Description for block {i}") + + return blocks + + def test_basic_model_card_content_structure(self): + """Test that all expected keys are present in the output.""" + blocks = self.create_mock_blocks() + content = generate_modular_model_card_content(blocks) + + expected_keys = [ + "pipeline_name", + "model_description", + "blocks_description", + "components_description", + "configs_section", + "inputs_description", + "outputs_description", + "trigger_inputs_section", + "tags", + ] + + for key in expected_keys: + assert key in content, f"Expected key '{key}' not found in model card content" + + assert isinstance(content["tags"], list), "Tags should be a list" + + def test_pipeline_name_generation(self): + """Test that pipeline name is correctly generated from blocks class name.""" + blocks = self.create_mock_blocks(class_name="StableDiffusionBlocks") + content = generate_modular_model_card_content(blocks) + + assert content["pipeline_name"] == "StableDiffusion Pipeline" 
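+        # Illustrative only (assumed naming rule, inferred from this assertion rather than
+        # from generate_modular_model_card_content itself): drop a trailing "Blocks" from
+        # the blocks class name and append " Pipeline", e.g.
+        #     "StableDiffusionBlocks".removesuffix("Blocks") + " Pipeline" == "StableDiffusion Pipeline"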
+ + def test_tags_generation_text_to_image(self): + """Test that text-to-image tags are correctly generated.""" + blocks = self.create_mock_blocks(trigger_inputs=None) + content = generate_modular_model_card_content(blocks) + + assert "modular-diffusers" in content["tags"] + assert "diffusers" in content["tags"] + assert "text-to-image" in content["tags"] + + def test_tags_generation_with_trigger_inputs(self): + """Test that tags are correctly generated based on trigger inputs.""" + # Test inpainting + blocks = self.create_mock_blocks(trigger_inputs=["mask", "prompt"]) + content = generate_modular_model_card_content(blocks) + assert "inpainting" in content["tags"] + + # Test image-to-image + blocks = self.create_mock_blocks(trigger_inputs=["image", "prompt"]) + content = generate_modular_model_card_content(blocks) + assert "image-to-image" in content["tags"] + + # Test controlnet + blocks = self.create_mock_blocks(trigger_inputs=["control_image", "prompt"]) + content = generate_modular_model_card_content(blocks) + assert "controlnet" in content["tags"] + + def test_tags_with_model_name(self): + """Test that model name is included in tags when present.""" + blocks = self.create_mock_blocks(model_name="stable-diffusion-xl") + content = generate_modular_model_card_content(blocks) + + assert "stable-diffusion-xl" in content["tags"] + + def test_components_description_formatting(self): + """Test that components are correctly formatted.""" + components = [ + ComponentSpec(name="vae", description="VAE component"), + ComponentSpec(name="text_encoder", description="Text encoder component"), + ] + blocks = self.create_mock_blocks(components=components) + content = generate_modular_model_card_content(blocks) + + assert "vae" in content["components_description"] + assert "text_encoder" in content["components_description"] + # Should be enumerated + assert "1." 
in content["components_description"] + + def test_components_description_empty(self): + """Test handling of pipelines without components.""" + blocks = self.create_mock_blocks(components=None) + content = generate_modular_model_card_content(blocks) + + assert "No specific components required" in content["components_description"] + + def test_configs_section_with_configs(self): + """Test that configs section is generated when configs are present.""" + configs = [ + ConfigSpec(name="num_train_timesteps", default=1000, description="Number of training timesteps"), + ] + blocks = self.create_mock_blocks(configs=configs) + content = generate_modular_model_card_content(blocks) + + assert "## Configuration Parameters" in content["configs_section"] + + def test_configs_section_empty(self): + """Test that configs section is empty when no configs are present.""" + blocks = self.create_mock_blocks(configs=None) + content = generate_modular_model_card_content(blocks) + + assert content["configs_section"] == "" + + def test_inputs_description_required_and_optional(self): + """Test that required and optional inputs are correctly formatted.""" + inputs = [ + InputParam(name="prompt", type_hint=str, required=True, description="The input prompt"), + InputParam(name="num_steps", type_hint=int, required=False, default=50, description="Number of steps"), + ] + blocks = self.create_mock_blocks(inputs=inputs) + content = generate_modular_model_card_content(blocks) + + assert "**Required:**" in content["inputs_description"] + assert "**Optional:**" in content["inputs_description"] + assert "prompt" in content["inputs_description"] + assert "num_steps" in content["inputs_description"] + assert "default: `50`" in content["inputs_description"] + + def test_inputs_description_empty(self): + """Test handling of pipelines without specific inputs.""" + blocks = self.create_mock_blocks(inputs=[]) + content = generate_modular_model_card_content(blocks) + + assert "No specific inputs defined" in content["inputs_description"] + + def test_outputs_description_formatting(self): + """Test that outputs are correctly formatted.""" + outputs = [ + OutputParam(name="images", type_hint=torch.Tensor, description="Generated images"), + ] + blocks = self.create_mock_blocks(outputs=outputs) + content = generate_modular_model_card_content(blocks) + + assert "images" in content["outputs_description"] + assert "Generated images" in content["outputs_description"] + + def test_outputs_description_empty(self): + """Test handling of pipelines without specific outputs.""" + blocks = self.create_mock_blocks(outputs=[]) + content = generate_modular_model_card_content(blocks) + + assert "Standard pipeline outputs" in content["outputs_description"] + + def test_trigger_inputs_section_with_triggers(self): + """Test that trigger inputs section is generated when present.""" + blocks = self.create_mock_blocks(trigger_inputs=["mask", "image"]) + content = generate_modular_model_card_content(blocks) + + assert "### Conditional Execution" in content["trigger_inputs_section"] + assert "`mask`" in content["trigger_inputs_section"] + assert "`image`" in content["trigger_inputs_section"] + + def test_trigger_inputs_section_empty(self): + """Test that trigger inputs section is empty when not present.""" + blocks = self.create_mock_blocks(trigger_inputs=None) + content = generate_modular_model_card_content(blocks) + + assert content["trigger_inputs_section"] == "" + + def test_blocks_description_with_sub_blocks(self): + """Test that blocks with sub-blocks are 
correctly described.""" + + class MockBlockWithSubBlocks: + def __init__(self): + self.__class__.__name__ = "ParentBlock" + self.description = "Parent block" + self.sub_blocks = { + "child1": self.create_child_block("ChildBlock1", "Child 1 description"), + "child2": self.create_child_block("ChildBlock2", "Child 2 description"), + } + + def create_child_block(self, name, desc): + class ChildBlock: + def __init__(self): + self.__class__.__name__ = name + self.description = desc + + return ChildBlock() + + blocks = self.create_mock_blocks() + blocks.sub_blocks["parent"] = MockBlockWithSubBlocks() + + content = generate_modular_model_card_content(blocks) + + assert "parent" in content["blocks_description"] + assert "child1" in content["blocks_description"] + assert "child2" in content["blocks_description"] + + def test_model_description_includes_block_count(self): + """Test that model description includes the number of blocks.""" + blocks = self.create_mock_blocks(num_blocks=5) + content = generate_modular_model_card_content(blocks) + + assert "5-block architecture" in content["model_description"] diff --git a/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py new file mode 100644 index 000000000000..9c5fd5be326d --- /dev/null +++ b/tests/modular_pipelines/test_modular_pipelines_custom_blocks.py @@ -0,0 +1,272 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
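+#
+# These tests cover user-defined ModularPipelineBlocks: block properties, running a custom
+# block through a modular pipeline, save/load round-trips with `trust_remote_code=True`
+# (the block's Python source is saved next to `modular_config.json`), loading a custom block
+# from the Hub, and slow integration tests against the krea/krea-realtime-video repository.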
+ +import json +import os +import tempfile +from collections import deque +from typing import List + +import numpy as np +import torch + +from diffusers import FluxTransformer2DModel +from diffusers.modular_pipelines import ( + ComponentSpec, + InputParam, + ModularPipelineBlocks, + OutputParam, + PipelineState, + WanModularPipeline, +) + +from ..testing_utils import nightly, require_torch, slow + + +class DummyCustomBlockSimple(ModularPipelineBlocks): + def __init__(self, use_dummy_model_component=False): + self.use_dummy_model_component = use_dummy_model_component + super().__init__() + + @property + def expected_components(self): + if self.use_dummy_model_component: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + else: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "output_prompt", + type_hint=str, + description="Modified prompt", + ) + ] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + old_prompt = block_state.prompt + block_state.output_prompt = "Modular diffusers + " + old_prompt + self.set_block_state(state, block_state) + + return components, state + + +CODE_STR = """ +from diffusers.modular_pipelines import ( + ComponentSpec, + InputParam, + ModularPipelineBlocks, + OutputParam, + PipelineState, + WanModularPipeline, +) +from typing import List + +class DummyCustomBlockSimple(ModularPipelineBlocks): + def __init__(self, use_dummy_model_component=False): + self.use_dummy_model_component = use_dummy_model_component + super().__init__() + + @property + def expected_components(self): + if self.use_dummy_model_component: + return [ComponentSpec("transformer", FluxTransformer2DModel)] + else: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [InputParam("prompt", type_hint=str, required=True, description="Prompt to use")] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "output_prompt", + type_hint=str, + description="Modified prompt", + ) + ] + + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + old_prompt = block_state.prompt + block_state.output_prompt = "Modular diffusers + " + old_prompt + self.set_block_state(state, block_state) + + return components, state +""" + + +class TestModularCustomBlocks: + def _test_block_properties(self, block): + assert not block.expected_components + assert not block.intermediate_inputs + + actual_inputs = [inp.name for inp in block.inputs] + actual_intermediate_outputs = [out.name for out in block.intermediate_outputs] + assert actual_inputs == ["prompt"] + assert actual_intermediate_outputs == ["output_prompt"] + + def test_custom_block_properties(self): + custom_block = DummyCustomBlockSimple() + self._test_block_properties(custom_block) + + def test_custom_block_output(self): + custom_block = DummyCustomBlockSimple() + pipe = custom_block.init_pipeline() + prompt = "Diffusers is nice" + output = pipe(prompt=prompt) + + actual_inputs = [inp.name for inp in custom_block.inputs] + actual_intermediate_outputs = [out.name for out in 
custom_block.intermediate_outputs] + assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs) + + output_prompt = output.values["output_prompt"] + assert output_prompt.startswith("Modular diffusers + ") + + def test_custom_block_saving_loading(self): + custom_block = DummyCustomBlockSimple() + + with tempfile.TemporaryDirectory() as tmpdir: + custom_block.save_pretrained(tmpdir) + assert any("modular_config.json" in k for k in os.listdir(tmpdir)) + + with open(os.path.join(tmpdir, "modular_config.json"), "r") as f: + config = json.load(f) + auto_map = config["auto_map"] + assert auto_map == {"ModularPipelineBlocks": "test_modular_pipelines_custom_blocks.DummyCustomBlockSimple"} + + # For now, the Python script that implements the custom block has to be manually pushed to the Hub. + # This is why, we have to separately save the Python script here. + code_path = os.path.join(tmpdir, "test_modular_pipelines_custom_blocks.py") + with open(code_path, "w") as f: + f.write(CODE_STR) + + loaded_custom_block = ModularPipelineBlocks.from_pretrained(tmpdir, trust_remote_code=True) + + pipe = loaded_custom_block.init_pipeline() + prompt = "Diffusers is nice" + output = pipe(prompt=prompt) + + actual_inputs = [inp.name for inp in loaded_custom_block.inputs] + actual_intermediate_outputs = [out.name for out in loaded_custom_block.intermediate_outputs] + assert sorted(output.values) == sorted(actual_inputs + actual_intermediate_outputs) + + output_prompt = output.values["output_prompt"] + assert output_prompt.startswith("Modular diffusers + ") + + def test_custom_block_supported_components(self): + custom_block = DummyCustomBlockSimple(use_dummy_model_component=True) + pipe = custom_block.init_pipeline("hf-internal-testing/tiny-flux-kontext-pipe") + pipe.load_components() + + assert len(pipe.components) == 1 + assert pipe.component_names[0] == "transformer" + + def test_custom_block_loads_from_hub(self): + repo_id = "hf-internal-testing/tiny-modular-diffusers-block" + block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True) + self._test_block_properties(block) + + pipe = block.init_pipeline() + + prompt = "Diffusers is nice" + output = pipe(prompt=prompt) + output_prompt = output.values["output_prompt"] + assert output_prompt.startswith("Modular diffusers + ") + + +@slow +@nightly +@require_torch +class TestKreaCustomBlocksIntegration: + repo_id = "krea/krea-realtime-video" + + def test_loading_from_hub(self): + blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) + block_names = sorted(blocks.sub_blocks) + + assert block_names == sorted(["text_encoder", "before_denoise", "denoise", "decode"]) + + pipe = WanModularPipeline(blocks, self.repo_id) + pipe.load_components( + trust_remote_code=True, + device_map="cuda", + torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, + ) + assert len(pipe.components) == 7 + assert sorted(pipe.components) == sorted( + ["text_encoder", "tokenizer", "guider", "scheduler", "vae", "transformer", "video_processor"] + ) + + def test_forward(self): + blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True) + pipe = WanModularPipeline(blocks, self.repo_id) + pipe.load_components( + trust_remote_code=True, + device_map="cuda", + torch_dtype={"default": torch.bfloat16, "vae": torch.float16}, + ) + + num_frames_per_block = 2 + num_blocks = 2 + + state = PipelineState() + state.set("frame_cache_context", deque(maxlen=pipe.config.frame_cache_len)) + + prompt = ["a 
cat sitting on a boat"] + + for block in pipe.transformer.blocks: + block.self_attn.fuse_projections() + + for block_idx in range(num_blocks): + state = pipe( + state, + prompt=prompt, + num_inference_steps=2, + num_blocks=num_blocks, + num_frames_per_block=num_frames_per_block, + block_idx=block_idx, + generator=torch.manual_seed(42), + ) + current_frames = np.array(state.values["videos"][0]) + current_frames_flat = current_frames.flatten() + actual_slices = np.concatenate([current_frames_flat[:4], current_frames_flat[-4:]]).tolist() + + if block_idx == 0: + assert current_frames.shape == (5, 480, 832, 3) + expected_slices = np.array([211, 229, 238, 208, 195, 180, 188, 193]) + else: + assert current_frames.shape == (8, 480, 832, 3) + expected_slices = np.array([179, 203, 214, 176, 194, 181, 187, 191]) + + assert np.allclose(actual_slices, expected_slices) diff --git a/tests/pipelines/bria_fibo_edit/__init__.py b/tests/pipelines/bria_fibo_edit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py new file mode 100644 index 000000000000..5376c4b5e03f --- /dev/null +++ b/tests/pipelines/bria_fibo_edit/test_pipeline_bria_fibo_edit.py @@ -0,0 +1,192 @@ +# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer +from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM + +from diffusers import ( + AutoencoderKLWan, + BriaFiboEditPipeline, + FlowMatchEulerDiscreteScheduler, +) +from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel +from tests.pipelines.test_pipelines_common import PipelineTesterMixin + +from ...testing_utils import ( + enable_full_determinism, + torch_device, +) + + +enable_full_determinism() + + +class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BriaFiboEditPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + test_xformers_attention = False + test_layerwise_casting = False + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = BriaFiboTransformer2DModel( + patch_size=1, + in_channels=16, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=64, + text_encoder_dim=32, + pooled_projection_dim=None, + axes_dims_rope=[0, 4, 4], + ) + + vae = AutoencoderKLWan( + base_dim=80, + decoder_base_dim=128, + dim_mult=[1, 2, 4, 4], + dropout=0.0, + in_channels=12, + latents_mean=[0.0] * 16, + latents_std=[1.0] * 16, + is_residual=True, + num_res_blocks=2, + out_channels=12, + patch_size=2, + scale_factor_spatial=16, + scale_factor_temporal=4, + temperal_downsample=[False, True, True], + z_dim=16, + ) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32)) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + inputs = { + "prompt": '{"text": "A painting of a squirrel eating a burger","edit_instruction": "A painting of a squirrel eating a burger"}', + "negative_prompt": "bad, ugly", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 192, + "width": 336, + "output_type": "np", + } + image = Image.new("RGB", (336, 192), (255, 255, 255)) + inputs["image"] = image + return inputs + + @unittest.skip(reason="will not be supported due to dim-fusion") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_num_images_per_prompt(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_consistent(self): + pass + + @unittest.skip(reason="Batching is not supported yet") + def test_inference_batch_single_identical(self): + pass + + def test_bria_fibo_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = {"edit_instruction": "a different prompt"} + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - 
output_different_prompts).max() + assert max_diff > 1e-6 + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (64, 64), (32, 64)] + for height, width in height_width_pairs: + expected_height = height + expected_width = width + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + + def test_bria_fibo_edit_mask(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((192, 336)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + output = pipe(**inputs).images[0] + + assert output.shape == (192, 336, 3) + + def test_bria_fibo_edit_mask_image_size_mismatch(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((64, 64)) * 255).astype(np.uint8), mode="L") + + inputs.update({"mask": mask}) + with self.assertRaisesRegex(ValueError, "Mask and image must have the same size"): + pipe(**inputs) + + def test_bria_fibo_edit_mask_no_image(self): + pipe = self.pipeline_class(**self.get_dummy_components()) + pipe = pipe.to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + mask = Image.fromarray((np.ones((32, 32)) * 255).astype(np.uint8), mode="L") + + # Remove image from inputs if it's there (it shouldn't be by default from get_dummy_inputs) + inputs.pop("image", None) + inputs.update({"mask": mask}) + + with self.assertRaisesRegex(ValueError, "If mask is provided, image must also be provided"): + pipe(**inputs) diff --git a/tests/pipelines/cosmos/cosmos_guardrail.py b/tests/pipelines/cosmos/cosmos_guardrail.py index 4de14fbaaf9d..c9ef597fdb36 100644 --- a/tests/pipelines/cosmos/cosmos_guardrail.py +++ b/tests/pipelines/cosmos/cosmos_guardrail.py @@ -27,7 +27,7 @@ class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin): def __init__(self) -> None: super().__init__() - self._dtype = torch.float32 + self.register_buffer("_device_tracker", torch.zeros(1, dtype=torch.float32), persistent=False) def check_text_safety(self, prompt: str) -> bool: return True @@ -35,13 +35,14 @@ def check_text_safety(self, prompt: str) -> bool: def check_video_safety(self, frames: np.ndarray) -> np.ndarray: return frames - def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None: - self._dtype = dtype + def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None): + module = super().to(device=device, dtype=dtype) + return module @property def device(self) -> torch.device: - return None + return self._device_tracker.device @property def dtype(self) -> torch.dtype: - return self._dtype + return self._device_tracker.dtype diff --git a/tests/pipelines/cosmos/test_cosmos2_5_predict.py b/tests/pipelines/cosmos/test_cosmos2_5_predict.py new file mode 100644 index 000000000000..54d4edb485fe --- /dev/null +++ b/tests/pipelines/cosmos/test_cosmos2_5_predict.py @@ -0,0 +1,337 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import json +import os +import tempfile +import unittest + +import numpy as np +import torch +from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer + +from diffusers import ( + AutoencoderKLWan, + Cosmos2_5_PredictBasePipeline, + CosmosTransformer3DModel, + UniPCMultistepScheduler, +) + +from ...testing_utils import enable_full_determinism, torch_device +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np +from .cosmos_guardrail import DummyCosmosSafetyChecker + + +enable_full_determinism() + + +class Cosmos2_5_PredictBaseWrapper(Cosmos2_5_PredictBasePipeline): + @staticmethod + def from_pretrained(*args, **kwargs): + if "safety_checker" not in kwargs or kwargs["safety_checker"] is None: + safety_checker = DummyCosmosSafetyChecker() + device_map = kwargs.get("device_map", "cpu") + torch_dtype = kwargs.get("torch_dtype") + if device_map is not None or torch_dtype is not None: + safety_checker = safety_checker.to(device_map, dtype=torch_dtype) + kwargs["safety_checker"] = safety_checker + return Cosmos2_5_PredictBasePipeline.from_pretrained(*args, **kwargs) + + +class Cosmos2_5_PredictPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Cosmos2_5_PredictBaseWrapper + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CosmosTransformer3DModel( + in_channels=16 + 1, + out_channels=16, + num_attention_heads=2, + attention_head_dim=16, + num_layers=2, + mlp_ratio=2, + text_embed_dim=32, + adaln_lora_dim=4, + max_size=(4, 32, 32), + patch_size=(1, 2, 2), + rope_scale=(2.0, 1.0, 1.0), + concat_padding_mask=True, + extra_pos_embed_type="learnable", + ) + + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = UniPCMultistepScheduler() + + torch.manual_seed(0) + config = Qwen2_5_VLConfig( + text_config={ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + }, + vision_config={ + "depth": 2, + "hidden_size": 16, + "intermediate_size": 16, + "num_heads": 2, + "out_hidden_size": 16, + }, + hidden_size=16, + vocab_size=152064, + vision_end_token_id=151653, + vision_start_token_id=151652, + vision_token_id=151654, + ) + 
text_encoder = Qwen2_5_VLForConditionalGeneration(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": DummyCosmosSafetyChecker(), + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + "num_frames": 3, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_components_function(self): + init_components = self.get_dummy_components() + init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))} + pipe = self.pipeline_class(**init_components) + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (3, 3, 32, 32)) + self.assertTrue(torch.isfinite(generated_video).all()) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + 
assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not getattr(self, "test_attention_slicing", True): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_save_load_optional_components(self, expected_max_difference=1e-4): + self.pipeline_class._optional_components.remove("safety_checker") + super().test_save_load_optional_components(expected_max_difference=expected_max_difference) + self.pipeline_class._optional_components.append("safety_checker") + + def test_serialization_with_variants(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + model_components = [ + component_name + for component_name, component in pipe.components.items() + if isinstance(component, torch.nn.Module) + ] + model_components.remove("safety_checker") + variant = "fp16" + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False) + + with open(f"{tmpdir}/model_index.json", "r") as f: + config = json.load(f) + + for subfolder in os.listdir(tmpdir): + if not os.path.isfile(subfolder) and subfolder in model_components: + folder_path = os.path.join(tmpdir, subfolder) + is_folder = os.path.isdir(folder_path) and subfolder in config + assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)) + + def test_torch_dtype_dict(self): + components = self.get_dummy_components() + if not components: + self.skipTest("No dummy components defined.") + + pipe = self.pipeline_class(**components) + + specified_key = next(iter(components.keys())) + + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: + pipe.save_pretrained(tmpdirname, safe_serialization=False) + torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16} + loaded_pipe = self.pipeline_class.from_pretrained( + tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict + ) + + for name, component in loaded_pipe.components.items(): + if name == "safety_checker": + continue + if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"): + expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32)) + self.assertEqual( + component.dtype, + expected_dtype, + f"Component 
'{name}' has dtype {component.dtype} but expected {expected_dtype}", + ) + + @unittest.skip( + "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in " + "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is " + "too large and slow to run on CI." + ) + def test_encode_prompt_works_in_isolation(self): + pass diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 1ddbd4ba3df8..f7476a21de57 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -27,8 +27,10 @@ FasterCacheTesterMixin, FirstBlockCacheTesterMixin, FluxIPAdapterTesterMixin, + MagCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, + TaylorSeerCacheTesterMixin, check_qkv_fused_layers_exist, ) @@ -39,6 +41,8 @@ class FluxPipelineFastTests( PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, + TaylorSeerCacheTesterMixin, + MagCacheTesterMixin, unittest.TestCase, ): pipeline_class = FluxPipeline diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein.py b/tests/pipelines/flux2/test_pipeline_flux2_klein.py new file mode 100644 index 000000000000..8ed9bf3d1e91 --- /dev/null +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein.py @@ -0,0 +1,183 @@ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2KleinPipeline, + Flux2Transformer2DModel, +) + +from ...testing_utils import torch_device +from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist + + +class Flux2KleinPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Flux2KleinPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = Flux2Transformer2DModel( + patch_size=1, + in_channels=4, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=16, + timestep_guidance_channels=256, + axes_dims_rope=[4, 4, 4, 4], + guidance_embeds=False, + ) + + # Create minimal Qwen3 config + config = Qwen3Config( + intermediate_size=16, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + torch.manual_seed(0) + text_encoder = Qwen3ForCausalLM(config) + + # Use a simple tokenizer for testing + tokenizer = Qwen2TokenizerFast.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) + + torch.manual_seed(0) + vae = AutoencoderKLFlux2( + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + 
"transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "a dog is dancing", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "height": 8, + "width": 8, + "max_sequence_length": 64, + "output_type": "np", + "text_encoder_out_layers": (1,), + } + return inputs + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + original_image_slice = image[0, -3:, -3:, -1] + + pipe.transformer.fuse_qkv_projections() + self.assertTrue( + check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]), + ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."), + ) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_fused = image[0, -3:, -3:, -1] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice_disabled = image[0, -3:, -3:, -1] + + self.assertTrue( + np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), + ("Fusion of QKV projections shouldn't affect the outputs."), + ) + self.assertTrue( + np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), + ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."), + ) + self.assertTrue( + np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), + ("Original outputs should match when fused QKV projections are disabled."), + ) + + def test_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}", + ) + + def test_image_input(self): + device = "cpu" + pipe = self.pipeline_class(**self.get_dummy_components()).to(device) + inputs = self.get_dummy_inputs(device) + + inputs["image"] = Image.new("RGB", (64, 64)) + image = pipe(**inputs).images.flatten() + generated_slice = np.concatenate([image[:8], image[-8:]]) + # fmt: off + expected_slice = np.array( + [ + 0.8255048 , 0.66054785, 0.6643694 , 0.67462724, 0.5494932 , 0.3480271 , 0.52535003, 0.44510138, 0.23549396, 0.21372932, 0.21166152, 0.63198495, 0.49942136, 0.39147034, 0.49156153, 0.3713916 + ] + ) + # fmt: on + assert np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4) + + @unittest.skip("Needs to be revisited") + def test_encode_prompt_works_in_isolation(self): + pass diff --git a/tests/pipelines/glm_image/__init__.py b/tests/pipelines/glm_image/__init__.py new file mode 100644 
index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/glm_image/test_glm_image.py b/tests/pipelines/glm_image/test_glm_image.py new file mode 100644 index 000000000000..d907d082d275 --- /dev/null +++ b/tests/pipelines/glm_image/test_glm_image.py @@ -0,0 +1,316 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel +from diffusers.utils import is_transformers_version + +from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +if is_transformers_version(">=", "5.0.0.dev0"): + from transformers import GlmImageConfig, GlmImageForConditionalGeneration, GlmImageProcessor + + +enable_full_determinism() + + +@require_transformers_version_greater("4.57.4") +@require_torch_accelerator +class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = GlmImagePipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + test_attention_slicing = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + glm_config = GlmImageConfig( + text_config={ + "vocab_size": 168064, + "hidden_size": 32, + "intermediate_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "max_position_embeddings": 512, + "vision_vocab_size": 128, + "rope_parameters": {"mrope_section": (4, 2, 2)}, + }, + vision_config={ + "depth": 2, + "hidden_size": 32, + "num_heads": 2, + "image_size": 32, + "patch_size": 8, + "intermediate_size": 32, + }, + vq_config={"embed_dim": 32, "num_embeddings": 128, "latent_channels": 32}, + ) + + torch.manual_seed(0) + vision_language_encoder = GlmImageForConditionalGeneration(glm_config) + + processor = GlmImageProcessor.from_pretrained("zai-org/GLM-Image", subfolder="processor") + + torch.manual_seed(0) + # For GLM-Image, the relationship between components must satisfy: + # patch_size × vae_scale_factor = 16 (since AR tokens are upsampled 2× from d32) + transformer = GlmImageTransformer2DModel( + patch_size=2, + in_channels=4, + out_channels=4, + num_layers=2, 
+ attention_head_dim=8, + num_attention_heads=2, + text_embed_dim=text_encoder.config.hidden_size, + time_embed_dim=16, + condition_dim=8, + prior_vq_quantizer_codebook_size=128, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=(4, 8, 16, 16), + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + norm_num_groups=4, + sample_size=128, + latents_mean=[0.0] * 4, + latents_std=[1.0] * 4, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "tokenizer": tokenizer, + "processor": processor, + "text_encoder": text_encoder, + "vision_language_encoder": vision_language_encoder, + "vae": vae, + "transformer": transformer, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + height, width = 32, 32 + + inputs = { + "prompt": "A photo of a cat", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.5, + "height": height, + "width": width, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images[0] + generated_slice = image.flatten() + generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]]) + + # fmt: off + expected_slice = np.array( + [ + 0.5849247, 0.50278825, 0.45747858, 0.45895284, 0.43804976, 0.47044256, 0.5239665, 0.47904694, 0.3323419, 0.38725388, 0.28505728, 0.3161863, 0.35026982, 0.37546024, 0.4090118, 0.46629113 + ] + ) + # fmt: on + + self.assertEqual(image.shape, (3, 32, 32)) + self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4)) + + def test_inference_batch_single_identical(self): + """Test that batch=1 produces consistent results with the same seed.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + # Run twice with same seed + inputs1 = self.get_dummy_inputs(device, seed=42) + inputs2 = self.get_dummy_inputs(device, seed=42) + + image1 = pipe(**inputs1).images[0] + image2 = pipe(**inputs2).images[0] + + self.assertTrue(torch.allclose(image1, image2, atol=1e-4)) + + def test_inference_batch_multiple_prompts(self): + """Test batch processing with multiple prompts.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(42) + height, width = 32, 32 + + inputs = { + "prompt": ["A photo of a cat", "A photo of a dog"], + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.5, + "height": height, + "width": width, + "max_sequence_length": 16, + "output_type": "pt", + } + + images = pipe(**inputs).images + + # Should return 2 images + self.assertEqual(len(images), 2) + self.assertEqual(images[0].shape, (3, 32, 32)) + self.assertEqual(images[1].shape, (3, 32, 32)) + + def 
test_num_images_per_prompt(self): + """Test generating multiple images per prompt.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(42) + height, width = 32, 32 + + inputs = { + "prompt": "A photo of a cat", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.5, + "height": height, + "width": width, + "max_sequence_length": 16, + "output_type": "pt", + "num_images_per_prompt": 2, + } + + images = pipe(**inputs).images + + # Should return 2 images for single prompt + self.assertEqual(len(images), 2) + self.assertEqual(images[0].shape, (3, 32, 32)) + self.assertEqual(images[1].shape, (3, 32, 32)) + + def test_batch_with_num_images_per_prompt(self): + """Test batch prompts with num_images_per_prompt > 1.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(42) + height, width = 32, 32 + + inputs = { + "prompt": ["A photo of a cat", "A photo of a dog"], + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.5, + "height": height, + "width": width, + "max_sequence_length": 16, + "output_type": "pt", + "num_images_per_prompt": 2, + } + + images = pipe(**inputs).images + + # Should return 4 images (2 prompts × 2 images per prompt) + self.assertEqual(len(images), 4) + + @unittest.skip("Needs to be revisited.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Needs to be revisited.") + def test_pipeline_level_group_offloading_inference(self): + pass + + @unittest.skip( + "Follow set of tests are relaxed because this pipeline doesn't guarantee same outputs for the same inputs in consecutive runs." 
+ ) + def test_dict_tuple_outputs_equivalent(self): + pass + + @unittest.skip("Skipped") + def test_cpu_offload_forward_pass_twice(self): + pass + + @unittest.skip("Skipped") + def test_sequential_offload_forward_pass_twice(self): + pass + + @unittest.skip("Skipped") + def test_float16_inference(self): + pass + + @unittest.skip("Skipped") + def test_save_load_float16(self): + pass + + @unittest.skip("Skipped") + def test_save_load_local(self): + pass diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index 4bdf3ee20e1b..57a6daebad1f 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -33,6 +33,7 @@ FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, + TaylorSeerCacheTesterMixin, to_np, ) @@ -45,6 +46,7 @@ class HunyuanVideoPipelineFastTests( PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, + TaylorSeerCacheTesterMixin, unittest.TestCase, ): pipeline_class = HunyuanVideoPipeline diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py index 8a693e9c2dd0..df1dd2d9872c 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py @@ -248,6 +248,9 @@ def test_inference_batch_single_identical(self): def test_float16_inference(self): super().test_float16_inference(expected_max_diff=5e-1) + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + @is_flaky() def test_model_cpu_offload_forward_pass(self): super().test_inference_batch_single_identical(expected_max_diff=8e-4) diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py index 503fdb242dff..d3bfa4b3082c 100644 --- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py +++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py @@ -191,6 +191,9 @@ def test_float16_inference(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=1e-2) + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-3, rtol=1e-3) + @slow @require_torch_accelerator diff --git a/tests/pipelines/longcat_image/__init__.py b/tests/pipelines/longcat_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ltx2/__init__.py b/tests/pipelines/ltx2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ltx2/test_ltx2.py b/tests/pipelines/ltx2/test_ltx2.py new file mode 100644 index 000000000000..7d1a3bfc9987 --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2.py @@ -0,0 +1,289 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
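+
+# Fast tests for LTX2Pipeline assembled from tiny dummy components (video and
+# audio VAEs, a Gemma3 text encoder with text connectors, a vocoder, and the
+# LTX2 video transformer). They check joint video/audio generation and a
+# two-stage run that feeds first-stage latents back through the pipeline.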
+ +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2Pipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "output_type", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + 
resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0263, 0.0528, 0.1217, 0.1104, 0.1632, 0.1072, 0.1789, 0.0949, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], 
components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.5514, 0.5943, 0.4260, 0.5971, 0.4306, 0.6369, 0.3124, 0.6964, 0.5419, 0.2412, 0.3882, 0.4504, 0.1941, 0.3404, 0.6037, 0.2464 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0252, 0.0526, 0.1211, 0.1119, 0.1638, 0.1042, 0.1776, 0.0948, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) diff --git a/tests/pipelines/ltx2/test_ltx2_image2video.py b/tests/pipelines/ltx2/test_ltx2_image2video.py new file mode 100644 index 000000000000..3653e1cfc5e4 --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2_image2video.py @@ -0,0 +1,291 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoTokenizer, Gemma3ForConditionalGeneration + +from diffusers import ( + AutoencoderKLLTX2Audio, + AutoencoderKLLTX2Video, + FlowMatchEulerDiscreteScheduler, + LTX2ImageToVideoPipeline, + LTX2VideoTransformer3DModel, +) +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + +from ...testing_utils import enable_full_determinism +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTX2ImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "audio_latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_attention_slicing = False + test_xformers_attention = False + supports_dduf = False + + base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3" + + def get_dummy_components(self): + tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id) + + torch.manual_seed(0) + transformer = LTX2VideoTransformer3DModel( + in_channels=4, + out_channels=4, + patch_size=1, + patch_size_t=1, + num_attention_heads=2, + attention_head_dim=8, + cross_attention_dim=16, + 
audio_in_channels=4, + audio_out_channels=4, + audio_num_attention_heads=2, + audio_attention_head_dim=4, + audio_cross_attention_dim=8, + num_layers=2, + qk_norm="rms_norm_across_heads", + caption_channels=text_encoder.config.text_config.hidden_size, + rope_double_precision=False, + rope_type="split", + ) + + torch.manual_seed(0) + connectors = LTX2TextConnectors( + caption_channels=text_encoder.config.text_config.hidden_size, + text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1, + video_connector_num_attention_heads=4, + video_connector_attention_head_dim=8, + video_connector_num_layers=1, + video_connector_num_learnable_registers=None, + audio_connector_num_attention_heads=4, + audio_connector_attention_head_dim=8, + audio_connector_num_layers=1, + audio_connector_num_learnable_registers=None, + connector_rope_base_seq_len=32, + rope_theta=10000.0, + rope_double_precision=False, + causal_temporal_positioning=False, + rope_type="split", + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTX2Video( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(8,), + decoder_block_out_channels=(8,), + layers_per_block=(1,), + decoder_layers_per_block=(1, 1), + spatio_temporal_scaling=(True,), + decoder_spatio_temporal_scaling=(True,), + decoder_inject_noise=(False, False), + downsample_type=("spatial",), + upsample_residual=(False,), + upsample_factor=(1,), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + audio_vae = AutoencoderKLLTX2Audio( + base_channels=4, + output_channels=2, + ch_mult=(1,), + num_res_blocks=1, + attn_resolutions=None, + in_channels=2, + resolution=32, + latent_channels=2, + norm_type="pixel", + causality_axis="height", + dropout=0.0, + mid_block_add_attention=False, + sample_rate=16000, + mel_hop_length=160, + is_causal=True, + mel_bins=8, + ) + + torch.manual_seed(0) + vocoder = LTX2Vocoder( + in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins, + hidden_channels=32, + out_channels=2, + upsample_kernel_sizes=[4, 4], + upsample_factors=[2, 2], + resnet_kernel_sizes=[3], + resnet_dilations=[[1, 3, 5]], + leaky_relu_negative_slope=0.1, + output_sampling_rate=16000, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + components = { + "transformer": transformer, + "vae": vae, + "audio_vae": audio_vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "connectors": connectors, + "vocoder": vocoder, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.rand((1, 3, 32, 32), generator=generator, device=device) + + inputs = { + "image": image, + "prompt": "a robot dancing", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "num_frames": 5, + "frame_rate": 25.0, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = pipe(**inputs) + video = output.frames + audio = output.audio + + 
self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0294, 0.0498, 0.1269, 0.1135, 0.1639, 0.1116, 0.1730, 0.0931, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_two_stages_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["output_type"] = "latent" + first_stage_output = pipe(**inputs) + video_latent = first_stage_output.frames + audio_latent = first_stage_output.audio + + self.assertEqual(video_latent.shape, (1, 4, 3, 16, 16)) + self.assertEqual(audio_latent.shape, (1, 2, 5, 2)) + self.assertEqual(audio_latent.shape[1], components["vocoder"].config.out_channels) + + inputs["latents"] = video_latent + inputs["audio_latents"] = audio_latent + inputs["output_type"] = "pt" + second_stage_output = pipe(**inputs) + video = second_stage_output.frames + audio = second_stage_output.audio + + self.assertEqual(video.shape, (1, 5, 3, 32, 32)) + self.assertEqual(audio.shape[0], 1) + self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels) + + # fmt: off + expected_video_slice = torch.tensor( + [ + 0.2665, 0.6915, 0.2939, 0.6767, 0.2552, 0.6215, 0.1765, 0.6248, 0.2800, 0.2356, 0.3480, 0.5395, 0.3190, 0.4128, 0.4784, 0.4086 + ] + ) + expected_audio_slice = torch.tensor( + [ + 0.0273, 0.0490, 0.1253, 0.1129, 0.1655, 0.1057, 0.1707, 0.0943, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739 + ] + ) + # fmt: on + + video = video.flatten() + audio = audio.flatten() + generated_video_slice = torch.cat([video[:8], video[-8:]]) + generated_audio_slice = torch.cat([audio[:8], audio[-8:]]) + + assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4) + assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2) diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index 2cb80df81adf..6e8535062a79 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -29,6 +29,7 @@ ) from ...testing_utils import ( + Expectations, backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, @@ -335,7 +336,14 @@ def test_pixart_512(self): image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images image_slice = image[0, -3:, -3:, -1] - expected_slice = np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017]) + + expected_slices = Expectations( + { + ("xpu", 3): np.array([0.0417, 0.0388, 0.0061, 
0.0618, 0.0517, 0.0420, 0.1038, 0.1055, 0.1257]), + ("cuda", None): np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017]), + } + ) + expected_slice = expected_slices.get_expectation() max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice) self.assertLessEqual(max_diff, 1e-4) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 22570b28841e..b3818e5fe4cc 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -35,7 +35,9 @@ from diffusers.hooks import apply_group_offloading from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.hooks.first_block_cache import FirstBlockCacheConfig +from diffusers.hooks.mag_cache import MagCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook +from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention import AttentionModuleMixin @@ -2924,6 +2926,110 @@ def run_forward(pipe): ) +class TaylorSeerCacheTesterMixin: + taylorseer_cache_config = TaylorSeerCacheConfig( + cache_interval=5, + disable_cache_before_step=10, + max_order=1, + taylor_factors_dtype=torch.bfloat16, + use_lite_mode=True, + ) + + def test_taylorseer_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 50 + return pipe(**inputs)[0] + + # Run inference without TaylorSeerCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with TaylorSeerCache enabled + pipe = create_pipe() + pipe.transformer.enable_cache(self.taylorseer_cache_config) + output = run_forward(pipe).flatten() + image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with TaylorSeerCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), ( + "TaylorSeerCache outputs should not differ much." + ) + assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), ( + "Outputs from normal inference and after disabling cache should not differ." 
+ ) + + +class MagCacheTesterMixin: + mag_cache_config = MagCacheConfig( + threshold=0.06, + max_skip_steps=3, + retention_ratio=0.2, + num_inference_steps=50, + mag_ratios=torch.ones(50), + ) + + def test_mag_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + # Match the config steps + inputs["num_inference_steps"] = 50 + return pipe(**inputs)[0] + + # 1. Run inference without MagCache (Baseline) + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # 2. Run inference with MagCache ENABLED + pipe = create_pipe() + pipe.transformer.enable_cache(self.mag_cache_config) + output = run_forward(pipe).flatten() + image_slice_enabled = np.concatenate((output[:8], output[-8:])) + + # 3. Run inference with MagCache DISABLED + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), ( + "MagCache outputs should not differ too much from baseline." + ) + + assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), ( + "Outputs after disabling cache should match original inference exactly." + ) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. diff --git a/tests/pipelines/z_image/test_z_image_img2img.py b/tests/pipelines/z_image/test_z_image_img2img.py new file mode 100644 index 000000000000..91b3025b17e8 --- /dev/null +++ b/tests/pipelines/z_image/test_z_image_img2img.py @@ -0,0 +1,358 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
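+
+# Fast tests for ZImageImg2ImgPipeline, covering img2img-specific behavior such as the strength parameter.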
+ +import gc +import os +import unittest + +import numpy as np +import torch +from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model + +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + ZImageImg2ImgPipeline, + ZImageTransformer2DModel, +) +from diffusers.utils.testing_utils import floats_tensor + +from ...testing_utils import torch_device +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations +# Cannot use enable_full_determinism() which sets it to True +# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(False) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +if hasattr(torch.backends, "cuda"): + torch.backends.cuda.matmul.allow_tf32 = False + + +class ZImageImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ZImageImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "strength", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + def setUp(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def tearDown(self): + super().tearDown() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = ZImageTransformer2DModel( + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=32, + n_layers=2, + n_refiner_layers=1, + n_heads=2, + n_kv_heads=2, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=16, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[8, 4, 4], + axes_lens=[256, 32, 32], + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + block_out_channels=[32, 64], + layers_per_block=1, + latent_channels=16, + norm_num_groups=32, + sample_size=32, + scaling_factor=0.3611, + shift_factor=0.1159, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + config = Qwen3Config( + hidden_size=16, + intermediate_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + text_encoder = Qwen3Model(config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + 
"text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + import random + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "bad quality", + "image": image, + "strength": 0.6, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "cfg_normalization": False, + "cfg_truncation": 1.0, + "height": 32, + "width": 32, + "max_sequence_length": 16, + "output_type": "np", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (32, 32, 3)) + + def test_inference_batch_single_identical(self): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + + def test_num_images_per_prompt(self): + import inspect + + sig = inspect.signature(self.pipeline_class.__call__) + + if "num_images_per_prompt" not in sig.parameters: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1, 2] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] + + assert images.shape[0] == batch_size * num_images_per_prompt + + del pipe + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def 
test_vae_tiling(self, expected_diff_max: float = 0.3): + import random + + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + # Generate a larger image for the input + inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu") + output_without_tiling = pipe(**inputs)[0] + + # With tiling (standard AutoencoderKL doesn't accept parameters) + pipe.vae.enable_tiling() + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu") + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4): + # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance + super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference) + + def test_group_offloading_inference(self): + # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine. + self.skipTest("Using test_pipeline_level_group_offloading_inference instead") + + def test_save_load_float16(self, expected_max_diff=1e-2): + # Z-Image does not support FP16 due to complex64 RoPE embeddings + self.skipTest("Z-Image does not support FP16 inference") + + def test_float16_inference(self, expected_max_diff=5e-2): + # Z-Image does not support FP16 due to complex64 RoPE embeddings + self.skipTest("Z-Image does not support FP16 inference") + + def test_strength_parameter(self): + """Test that strength parameter affects the output correctly.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + # Test with different strength values + inputs_low_strength = self.get_dummy_inputs(device) + inputs_low_strength["strength"] = 0.2 + + inputs_high_strength = self.get_dummy_inputs(device) + inputs_high_strength["strength"] = 0.8 + + # Both should complete without errors + output_low = pipe(**inputs_low_strength).images[0] + output_high = pipe(**inputs_high_strength).images[0] + + # Outputs should be different (different amount of transformation) + self.assertFalse(np.allclose(output_low, output_high, atol=1e-3)) + + def test_invalid_strength(self): + """Test that invalid strength values raise appropriate errors.""" + device = "cpu" + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + + inputs = self.get_dummy_inputs(device) + + # Test strength < 0 + inputs["strength"] = -0.1 + with self.assertRaises(ValueError): + pipe(**inputs) + + # Test strength > 1 + inputs["strength"] = 1.5 + with self.assertRaises(ValueError): + pipe(**inputs) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index fde3966dec97..031fdc9f9e27 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -288,31 +288,29 @@ def test_config_from_pretrained(self): self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) 
self.assertTrue(hasattr(linear.weight, "SCB")) + @require_bitsandbytes_version_greater("0.48.0") def test_device_and_dtype_assignment(self): r""" Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. Checks also if other models are casted correctly. """ - with self.assertRaises(ValueError): - # Tries with `str` - self.model_8bit.to("cpu") with self.assertRaises(ValueError): # Tries with a `dtype`` self.model_8bit.to(torch.float16) - with self.assertRaises(ValueError): - # Tries with a `device` - self.model_8bit.to(torch.device(f"{torch_device}:0")) - with self.assertRaises(ValueError): # Tries with a `device` self.model_8bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a `dtype` self.model_8bit.half() + # This should work with 0.48.0 + self.model_8bit.to("cpu") + self.model_8bit.to(torch.device(f"{torch_device}:0")) + # Test if we did not break anything self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) input_dict_for_transformer = self.get_dummy_inputs() @@ -837,7 +835,7 @@ def test_serialization_sharded(self): @require_torch_version_greater_equal("2.6.0") -@require_bitsandbytes_version_greater("0.45.5") +@require_bitsandbytes_version_greater("0.48.0") class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase): @property def quantization_config(self): @@ -848,7 +846,7 @@ def quantization_config(self): ) @pytest.mark.xfail( - reason="Test fails because of an offloading problem from Accelerate with confusion in hooks." + reason="Test fails because of a type change when recompiling." " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details." 
) def test_torch_compile(self): @@ -858,6 +856,5 @@ def test_torch_compile(self): def test_torch_compile_with_cpu_offload(self): super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16) - @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") def test_torch_compile_with_group_offload_leaf(self): super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True) diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 38997de17b12..7a8e3cc67877 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -35,6 +35,7 @@ from diffusers.quantizers import PipelineQuantizationConfig from ...testing_utils import ( + Expectations, backend_empty_cache, backend_synchronize, enable_full_determinism, @@ -255,9 +256,12 @@ def test_quantization(self): # Cutlass fails to initialize for below # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), # ===== - ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), ]) + if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) # fmt: on for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: @@ -270,6 +274,34 @@ def test_quantization(self): ) self._test_quant_type(quantization_config, expected_slice, model_id) + @unittest.skip("Skipping floatx quantization tests") + def test_floatx_quantization(self): + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): + if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"): + quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"]) + self._test_quant_type( + quantization_config, + np.array( + [ + 0.4648, + 0.5195, + 0.5547, + 0.4180, + 0.4434, + 0.6445, + 0.4316, + 0.4531, + 0.5625, + ] + ), + model_id, + ) + else: + # Make sure the correct error is thrown + with self.assertRaisesRegex(ValueError, "Please downgrade"): + quantization_config = TorchAoConfig(quant_type="fp4", modules_to_not_convert=["x_embedder"]) + def test_int4wo_quant_bfloat16_conversion(self): """ Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization. @@ -497,8 +529,23 @@ def test_memory_footprint(self): def test_model_memory_usage(self): model_id = "hf-internal-testing/tiny-flux-pipe" - expected_memory_saving_ratio = 2.0 - + expected_memory_saving_ratios = Expectations( + { + # XPU: For this tiny model, per-tensor overheads (alignment, fragmentation, metadata) become visible. + # While XPU doesn't have the large fixed cuBLAS workspace of A100, these small overheads prevent reaching the ideal 2.0 ratio. + # Observed ~1.27x (158k vs 124k) for model size. + # The runtime memory overhead is ~88k for both bf16 and int8wo. Adding this to model size: (158k+88k)/(124k+88k) ≈ 1.15. + ("xpu", None): 1.15, + # On Ampere, the cuBLAS kernels used for matrix multiplication often allocate a fixed-size workspace. 
+ # Since the tiny-flux model weights are likely smaller than or comparable to this workspace, the total memory is dominated by the workspace. + ("cuda", 8): 1.02, + # On Hopper, TorchAO utilizes newer, highly optimized kernels (via Triton or CUTLASS 3.x) that are designed to be workspace-free or use negligible extra memory. + # Additionally, Triton kernels often handle unaligned memory better, avoiding the padding overhead seen on other backends for tiny tensors. + # This allows it to achieve the near-ideal 2.0x compression ratio. + ("cuda", 9): 2.0, + } + ) + expected_memory_saving_ratio = expected_memory_saving_ratios.get_expectation() inputs = self.get_dummy_tensor_inputs(device=torch_device) transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] @@ -778,8 +825,11 @@ def test_quantization(self): if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), - ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), ]) + if version.parse(importlib.metadata.version("torchao")) <= version.Version("0.14.1"): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), + ]) # fmt: on for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 197c831cb015..ac7e1d3f88b4 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -399,3 +399,32 @@ def test_beta_sigmas(self): def test_exponential_sigmas(self): self.check_over_configs(use_exponential_sigmas=True) + + def test_flow_and_karras_sigmas(self): + self.check_over_configs(use_flow_sigmas=True, use_karras_sigmas=True) + + def test_flow_and_karras_sigmas_values(self): + num_train_timesteps = 1000 + num_inference_steps = 5 + scheduler = UniPCMultistepScheduler( + sigma_min=0.01, + sigma_max=200.0, + use_flow_sigmas=True, + use_karras_sigmas=True, + num_train_timesteps=num_train_timesteps, + ) + scheduler.set_timesteps(num_inference_steps=num_inference_steps) + + expected_sigmas = [ + 0.9950248599052429, + 0.9787454605102539, + 0.8774884343147278, + 0.3604971766471863, + 0.009900986216962337, + 0.0, # 0 appended as default + ] + expected_sigmas = torch.tensor(expected_sigmas) + expected_timesteps = (expected_sigmas * num_train_timesteps).to(torch.int64) + expected_timesteps = expected_timesteps[0:-1] + self.assertTrue(torch.allclose(scheduler.sigmas, expected_sigmas)) + self.assertTrue(torch.all(expected_timesteps == scheduler.timesteps)) diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 6ed7e3467d7f..43f3253925df 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -38,6 +38,7 @@ is_gguf_available, is_kernels_available, 
is_note_seq_available, + is_nvidia_modelopt_version, is_onnx_available, is_opencv_available, is_optimum_quanto_available, @@ -130,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs): return True +def assert_tensors_close( + actual: "torch.Tensor", + expected: "torch.Tensor", + atol: float = 1e-5, + rtol: float = 1e-5, + msg: str = "", +) -> None: + """ + Assert that two tensors are close within tolerance. + + Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected| + Provides concise, actionable error messages without dumping full tensors. + + Args: + actual: The actual tensor from the computation. + expected: The expected tensor to compare against. + atol: Absolute tolerance. + rtol: Relative tolerance. + msg: Optional message prefix for the assertion error. + + Raises: + AssertionError: If tensors have different shapes or values exceed tolerance. + + Example: + >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass") + """ + if not is_torch_available(): + raise ValueError("PyTorch needs to be installed to use this function.") + + if actual.shape != expected.shape: + raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}") + + if not torch.allclose(actual, expected, atol=atol, rtol=rtol): + abs_diff = (actual - expected).abs() + max_diff = abs_diff.max().item() + + flat_idx = abs_diff.argmax().item() + max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist()) + + threshold = atol + rtol * expected.abs() + mismatched = (abs_diff > threshold).sum().item() + total = actual.numel() + + raise AssertionError( + f"{msg}\n" + f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n" + f" Max diff: {max_diff:.6e} at index {max_idx}\n" + f" Actual: {actual.flatten()[flat_idx].item():.6e}\n" + f" Expected: {expected.flatten()[flat_idx].item():.6e}\n" + f" atol: {atol:.6e}, rtol: {rtol:.6e}" + ) + + def numpy_cosine_similarity_distance(a, b): similarity = np.dot(a, b) / (norm(a) * norm(b)) distance = 1.0 - similarity.mean() @@ -241,7 +295,6 @@ def parse_flag_from_env(key, default=False): _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) _run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False) -_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False) def floats_tensor(shape, scale=1.0, rng=None, name=None): @@ -282,12 +335,155 @@ def nightly(test_case): def is_torch_compile(test_case): """ - Decorator marking a test that runs compile tests in the diffusers CI. + Decorator marking a test as a torch.compile test. These tests can be filtered using: + pytest -m "not compile" to skip + pytest -m compile to run only these tests + """ + return pytest.mark.compile(test_case) + + +def is_single_file(test_case): + """ + Decorator marking a test as a single file loading test. These tests can be filtered using: + pytest -m "not single_file" to skip + pytest -m single_file to run only these tests + """ + return pytest.mark.single_file(test_case) - Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them. +def is_lora(test_case): """ - return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case) + Decorator marking a test as a LoRA test. 
These tests can be filtered using: + pytest -m "not lora" to skip + pytest -m lora to run only these tests + """ + return pytest.mark.lora(test_case) + + +def is_ip_adapter(test_case): + """ + Decorator marking a test as an IP Adapter test. These tests can be filtered using: + pytest -m "not ip_adapter" to skip + pytest -m ip_adapter to run only these tests + """ + return pytest.mark.ip_adapter(test_case) + + +def is_training(test_case): + """ + Decorator marking a test as a training test. These tests can be filtered using: + pytest -m "not training" to skip + pytest -m training to run only these tests + """ + return pytest.mark.training(test_case) + + +def is_attention(test_case): + """ + Decorator marking a test as an attention test. These tests can be filtered using: + pytest -m "not attention" to skip + pytest -m attention to run only these tests + """ + return pytest.mark.attention(test_case) + + +def is_memory(test_case): + """ + Decorator marking a test as a memory optimization test. These tests can be filtered using: + pytest -m "not memory" to skip + pytest -m memory to run only these tests + """ + return pytest.mark.memory(test_case) + + +def is_cpu_offload(test_case): + """ + Decorator marking a test as a CPU offload test. These tests can be filtered using: + pytest -m "not cpu_offload" to skip + pytest -m cpu_offload to run only these tests + """ + return pytest.mark.cpu_offload(test_case) + + +def is_group_offload(test_case): + """ + Decorator marking a test as a group offload test. These tests can be filtered using: + pytest -m "not group_offload" to skip + pytest -m group_offload to run only these tests + """ + return pytest.mark.group_offload(test_case) + + +def is_quantization(test_case): + """ + Decorator marking a test as a quantization test. These tests can be filtered using: + pytest -m "not quantization" to skip + pytest -m quantization to run only these tests + """ + return pytest.mark.quantization(test_case) + + +def is_bitsandbytes(test_case): + """ + Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using: + pytest -m "not bitsandbytes" to skip + pytest -m bitsandbytes to run only these tests + """ + return pytest.mark.bitsandbytes(test_case) + + +def is_quanto(test_case): + """ + Decorator marking a test as a Quanto quantization test. These tests can be filtered using: + pytest -m "not quanto" to skip + pytest -m quanto to run only these tests + """ + return pytest.mark.quanto(test_case) + + +def is_torchao(test_case): + """ + Decorator marking a test as a TorchAO quantization test. These tests can be filtered using: + pytest -m "not torchao" to skip + pytest -m torchao to run only these tests + """ + return pytest.mark.torchao(test_case) + + +def is_gguf(test_case): + """ + Decorator marking a test as a GGUF quantization test. These tests can be filtered using: + pytest -m "not gguf" to skip + pytest -m gguf to run only these tests + """ + return pytest.mark.gguf(test_case) + + +def is_modelopt(test_case): + """ + Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using: + pytest -m "not modelopt" to skip + pytest -m modelopt to run only these tests + """ + return pytest.mark.modelopt(test_case) + + +def is_context_parallel(test_case): + """ + Decorator marking a test as a context parallel inference test. 
These tests can be filtered using: + pytest -m "not context_parallel" to skip + pytest -m context_parallel to run only these tests + """ + return pytest.mark.context_parallel(test_case) + + +def is_cache(test_case): + """ + Decorator marking a test as a cache test. These tests can be filtered using: + pytest -m "not cache" to skip + pytest -m cache to run only these tests + """ + return pytest.mark.cache(test_case) def require_torch(test_case): @@ -650,6 +846,16 @@ def decorator(test_case): return decorator +def require_modelopt_version_greater_or_equal(modelopt_version): + def decorator(test_case): + return pytest.mark.skipif( + not is_nvidia_modelopt_version(">=", modelopt_version), + reason=f"Test requires modelopt with version greater than {modelopt_version}.", + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend @@ -1424,6 +1630,8 @@ def _get_expected_safetensors_files( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, + block_modules: Optional[List[str]] = None, + module_prefix: str = "", ) -> Set[str]: expected_files = set() @@ -1435,23 +1643,36 @@ def get_hashed_filename(group_id: str) -> str: if num_blocks_per_group is None: raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") - # Handle groups of ModuleList and Sequential blocks + block_modules_set = set(block_modules) if block_modules is not None else set() + + modules_with_group_offloading = set() unmatched_modules = [] for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - unmatched_modules.append(module) - continue + if name in block_modules_set: + new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}." 
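+ # Recurse into explicitly listed block modules, carrying the dotted prefix so that group ids (and the hashed filenames derived from them) stay unique per parent module.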
+ submodule_files = _get_expected_safetensors_files( + submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix + ) + expected_files.update(submodule_files) + modules_with_group_offloading.add(name) + + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] + if not current_modules: + continue + group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}" + expected_files.add(get_hashed_filename(group_id)) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + unmatched_modules.append(submodule) - for i in range(0, len(submodule), num_blocks_per_group): - current_modules = submodule[i : i + num_blocks_per_group] - if not current_modules: - continue - group_id = f"{name}_{i}_{i + len(current_modules) - 1}" - expected_files.add(get_hashed_filename(group_id)) + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) - # Handle the group for unmatched top-level modules and parameters - for module in unmatched_modules: - expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group")) + if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: + expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group")) elif offload_type == "leaf_level": # Handle leaf-level module groups @@ -1492,12 +1713,13 @@ def _check_safetensors_serialization( offload_to_disk_path: str, offload_type: str, num_blocks_per_group: Optional[int] = None, + block_modules: Optional[List[str]] = None, ) -> bool: if not os.path.isdir(offload_to_disk_path): return False, None, None expected_files = _get_expected_safetensors_files( - module, offload_to_disk_path, offload_type, num_blocks_per_group + module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules ) actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) missing_files = expected_files - actual_files diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py new file mode 100644 index 000000000000..11acd2175e21 --- /dev/null +++ b/utils/generate_model_tests.py @@ -0,0 +1,592 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility script to generate test suites for diffusers model classes. + +Usage: + python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py + +This will analyze the model file and generate a test file with appropriate +test classes based on the model's mixins and attributes. 
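+
+The generated file is a scaffold: Hub repository IDs and dummy inputs are emitted as TODO placeholders that must be filled in before the tests can run.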
+""" + +import argparse +import ast +import sys +from pathlib import Path + + +MIXIN_TO_TESTER = { + "ModelMixin": "ModelTesterMixin", + "PeftAdapterMixin": "LoraTesterMixin", +} + +ATTRIBUTE_TO_TESTER = { + "_cp_plan": "ContextParallelTesterMixin", + "_supports_gradient_checkpointing": "TrainingTesterMixin", +} + +ALWAYS_INCLUDE_TESTERS = [ + "ModelTesterMixin", + "MemoryTesterMixin", + "TorchCompileTesterMixin", +] + +# Attention-related class names that indicate the model uses attention +ATTENTION_INDICATORS = { + "AttentionMixin", + "AttentionModuleMixin", +} + +OPTIONAL_TESTERS = [ + # Quantization testers + ("BitsAndBytesTesterMixin", "bnb"), + ("QuantoTesterMixin", "quanto"), + ("TorchAoTesterMixin", "torchao"), + ("GGUFTesterMixin", "gguf"), + ("ModelOptTesterMixin", "modelopt"), + # Quantization compile testers + ("BitsAndBytesCompileTesterMixin", "bnb_compile"), + ("QuantoCompileTesterMixin", "quanto_compile"), + ("TorchAoCompileTesterMixin", "torchao_compile"), + ("GGUFCompileTesterMixin", "gguf_compile"), + ("ModelOptCompileTesterMixin", "modelopt_compile"), + # Cache testers + ("PyramidAttentionBroadcastTesterMixin", "pab_cache"), + ("FirstBlockCacheTesterMixin", "fbc_cache"), + ("FasterCacheTesterMixin", "faster_cache"), + # Other testers + ("SingleFileTesterMixin", "single_file"), + ("IPAdapterTesterMixin", "ip_adapter"), +] + + +class ModelAnalyzer(ast.NodeVisitor): + def __init__(self): + self.model_classes = [] + self.current_class = None + self.imports = set() + + def visit_Import(self, node: ast.Import): + for alias in node.names: + self.imports.add(alias.name.split(".")[-1]) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom): + for alias in node.names: + self.imports.add(alias.name) + self.generic_visit(node) + + def visit_ClassDef(self, node: ast.ClassDef): + base_names = [] + for base in node.bases: + if isinstance(base, ast.Name): + base_names.append(base.id) + elif isinstance(base, ast.Attribute): + base_names.append(base.attr) + + if "ModelMixin" in base_names: + class_info = { + "name": node.name, + "bases": base_names, + "attributes": {}, + "has_forward": False, + "init_params": [], + } + + for item in node.body: + if isinstance(item, ast.Assign): + for target in item.targets: + if isinstance(target, ast.Name): + attr_name = target.id + if attr_name.startswith("_"): + class_info["attributes"][attr_name] = self._get_value(item.value) + + elif isinstance(item, ast.FunctionDef): + if item.name == "forward": + class_info["has_forward"] = True + class_info["forward_params"] = self._extract_func_params(item) + elif item.name == "__init__": + class_info["init_params"] = self._extract_func_params(item) + + self.model_classes.append(class_info) + + self.generic_visit(node) + + def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]: + params = [] + args = func_node.args + + num_defaults = len(args.defaults) + num_args = len(args.args) + first_default_idx = num_args - num_defaults + + for i, arg in enumerate(args.args): + if arg.arg == "self": + continue + + param_info = {"name": arg.arg, "type": None, "default": None} + + if arg.annotation: + param_info["type"] = self._get_annotation_str(arg.annotation) + + default_idx = i - first_default_idx + if default_idx >= 0 and default_idx < len(args.defaults): + param_info["default"] = self._get_value(args.defaults[default_idx]) + + params.append(param_info) + + return params + + def _get_annotation_str(self, node) -> str: + if isinstance(node, ast.Name): + return node.id + elif 
isinstance(node, ast.Constant): + return repr(node.value) + elif isinstance(node, ast.Subscript): + base = self._get_annotation_str(node.value) + if isinstance(node.slice, ast.Tuple): + args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts) + else: + args = self._get_annotation_str(node.slice) + return f"{base}[{args}]" + elif isinstance(node, ast.Attribute): + return f"{self._get_annotation_str(node.value)}.{node.attr}" + elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + left = self._get_annotation_str(node.left) + right = self._get_annotation_str(node.right) + return f"{left} | {right}" + elif isinstance(node, ast.Tuple): + return ", ".join(self._get_annotation_str(el) for el in node.elts) + return "Any" + + def _get_value(self, node): + if isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Name): + if node.id == "None": + return None + elif node.id == "True": + return True + elif node.id == "False": + return False + return node.id + elif isinstance(node, ast.List): + return [self._get_value(el) for el in node.elts] + elif isinstance(node, ast.Dict): + return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)} + return "" + + +def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]: + with open(filepath) as f: + source = f.read() + + tree = ast.parse(source) + analyzer = ModelAnalyzer() + analyzer.visit(tree) + + return analyzer.model_classes, analyzer.imports + + +def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]: + testers = list(ALWAYS_INCLUDE_TESTERS) + + for base in model_info["bases"]: + if base in MIXIN_TO_TESTER: + tester = MIXIN_TO_TESTER[base] + if tester not in testers: + testers.append(tester) + + for attr, tester in ATTRIBUTE_TO_TESTER.items(): + if attr in model_info["attributes"]: + value = model_info["attributes"][attr] + if value is not None and value is not False: + if tester not in testers: + testers.append(tester) + + if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None: + if "ContextParallelTesterMixin" not in testers: + testers.append("ContextParallelTesterMixin") + + # Include AttentionTesterMixin if the model imports attention-related classes + if imports & ATTENTION_INDICATORS: + testers.append("AttentionTesterMixin") + + for tester, flag in OPTIONAL_TESTERS: + if flag in include_optional: + if tester not in testers: + testers.append(tester) + + return testers + + +def generate_config_class(model_info: dict, model_name: str) -> str: + class_name = f"{model_name}TesterConfig" + model_class = model_info["name"] + forward_params = model_info.get("forward_params", []) + init_params = model_info.get("init_params", []) + + lines = [ + f"class {class_name}:", + " @property", + " def model_class(self):", + f" return {model_class}", + "", + " @property", + " def pretrained_model_name_or_path(self):", + ' return "" # TODO: Set Hub repository ID', + "", + " @property", + " def pretrained_model_kwargs(self):", + ' return {"subfolder": "transformer"}', + "", + " @property", + " def generator(self):", + ' return torch.Generator("cpu").manual_seed(0)', + "", + " def get_init_dict(self) -> dict[str, int | list[int]]:", + ] + + if init_params: + lines.append(" # __init__ parameters:") + for param in init_params: + type_str = f": {param['type']}" if param["type"] else "" + default_str = f" = {param['default']}" if param["default"] is not None else "" + lines.append(f" # 
{param['name']}{type_str}{default_str}") + + lines.extend( + [ + " return {}", + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + ] + ) + + if forward_params: + lines.append(" # forward() parameters:") + for param in forward_params: + type_str = f": {param['type']}" if param["type"] else "" + default_str = f" = {param['default']}" if param["default"] is not None else "" + lines.append(f" # {param['name']}{type_str}{default_str}") + + lines.extend( + [ + " # TODO: Fill in dummy inputs", + " return {}", + "", + " @property", + " def input_shape(self) -> tuple[int, ...]:", + " return (1, 1)", + "", + " @property", + " def output_shape(self) -> tuple[int, ...]:", + " return (1, 1)", + ] + ) + + return "\n".join(lines) + + +def generate_test_class(model_name: str, config_class: str, tester: str) -> str: + tester_short = tester.replace("TesterMixin", "") + class_name = f"Test{model_name}{tester_short}" + + lines = [f"class {class_name}({config_class}, {tester}):"] + + if tester == "TorchCompileTesterMixin": + lines.extend( + [ + " @property", + " def different_shapes_for_compilation(self):", + " return [(4, 4), (4, 8), (8, 8)]", + "", + " def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:", + " # TODO: Implement dynamic input generation", + " return {}", + ] + ) + elif tester == "IPAdapterTesterMixin": + lines.extend( + [ + " @property", + " def ip_adapter_processor_cls(self):", + " return None # TODO: Set processor class", + "", + " def modify_inputs_for_ip_adapter(self, model, inputs_dict):", + " # TODO: Add IP adapter image embeds to inputs", + " return inputs_dict", + "", + " def create_ip_adapter_state_dict(self, model):", + " # TODO: Create IP adapter state dict", + " return {}", + ] + ) + elif tester == "SingleFileTesterMixin": + lines.extend( + [ + " @property", + " def ckpt_path(self):", + ' return "" # TODO: Set checkpoint path', + "", + " @property", + " def alternate_ckpt_paths(self):", + " return []", + "", + " @property", + " def pretrained_model_name_or_path(self):", + ' return "" # TODO: Set Hub repository ID', + ] + ) + elif tester == "GGUFTesterMixin": + lines.extend( + [ + " @property", + " def gguf_filename(self):", + ' return "" # TODO: Set GGUF filename', + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization tests", + " return {}", + ] + ) + elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]: + lines.extend( + [ + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization tests", + " return {}", + ] + ) + elif tester in [ + "BitsAndBytesCompileTesterMixin", + "QuantoCompileTesterMixin", + "TorchAoCompileTesterMixin", + "ModelOptCompileTesterMixin", + ]: + lines.extend( + [ + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization compile tests", + " return {}", + ] + ) + elif tester == "GGUFCompileTesterMixin": + lines.extend( + [ + " @property", + " def gguf_filename(self):", + ' return "" # TODO: Set GGUF filename', + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization compile tests", + " return {}", + ] + ) + elif tester in [ + "PyramidAttentionBroadcastTesterMixin", + "FirstBlockCacheTesterMixin", + "FasterCacheTesterMixin", + ]: + lines.append(" pass") + elif tester == "LoraHotSwappingForModelTesterMixin": 
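+ # The hot-swapping tester reuses the same dynamic-shape stubs that are emitted for the torch.compile tester.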
+ lines.extend( + [ + " @property", + " def different_shapes_for_compilation(self):", + " return [(4, 4), (4, 8), (8, 8)]", + "", + " def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:", + " # TODO: Implement dynamic input generation", + " return {}", + ] + ) + else: + lines.append(" pass") + + return "\n".join(lines) + + +def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str: + model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "") + testers = determine_testers(model_info, include_optional, imports) + tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"}) + + lines = [ + "# coding=utf-8", + "# Copyright 2025 HuggingFace Inc.", + "#", + '# Licensed under the Apache License, Version 2.0 (the "License");', + "# you may not use this file except in compliance with the License.", + "# You may obtain a copy of the License at", + "#", + "# http://www.apache.org/licenses/LICENSE-2.0", + "#", + "# Unless required by applicable law or agreed to in writing, software", + '# distributed under the License is distributed on an "AS IS" BASIS,', + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "# See the License for the specific language governing permissions and", + "# limitations under the License.", + "", + "import torch", + "", + f"from diffusers import {model_info['name']}", + "from diffusers.utils.torch_utils import randn_tensor", + "", + "from ...testing_utils import enable_full_determinism, torch_device", + ] + + if "LoraTesterMixin" in testers: + lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin") + + lines.extend( + [ + "from ..testing_utils import (", + *[f" {tester}," for tester in sorted(tester_imports)], + ")", + "", + "", + "enable_full_determinism()", + "", + "", + ] + ) + + config_class = f"{model_name}TesterConfig" + lines.append(generate_config_class(model_info, model_name)) + lines.append("") + lines.append("") + + for tester in testers: + lines.append(generate_test_class(model_name, config_class, tester)) + lines.append("") + lines.append("") + + if "LoraTesterMixin" in testers: + lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin")) + lines.append("") + lines.append("") + + return "\n".join(lines).rstrip() + "\n" + + +def get_test_output_path(model_filepath: str) -> str: + path = Path(model_filepath) + model_filename = path.stem + + if "transformers" in path.parts: + return f"tests/models/transformers/test_models_{model_filename}.py" + elif "unets" in path.parts: + return f"tests/models/unets/test_models_{model_filename}.py" + elif "autoencoders" in path.parts: + return f"tests/models/autoencoders/test_models_{model_filename}.py" + else: + return f"tests/models/test_models_{model_filename}.py" + + +def main(): + parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class") + parser.add_argument( + "model_filepath", + type=str, + help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)", + ) + parser.add_argument( + "--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)" + ) + parser.add_argument( + "--include", + "-i", + type=str, + nargs="*", + default=[], + choices=[ + "bnb", + "quanto", + "torchao", + "gguf", + "modelopt", + "bnb_compile", + "quanto_compile", + "torchao_compile", + 
"gguf_compile", + "modelopt_compile", + "pab_cache", + "fbc_cache", + "faster_cache", + "single_file", + "ip_adapter", + "all", + ], + help="Optional testers to include", + ) + parser.add_argument( + "--class-name", + "-c", + type=str, + default=None, + help="Specific model class to generate tests for (default: first model class found)", + ) + parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file") + + args = parser.parse_args() + + if not Path(args.model_filepath).exists(): + print(f"Error: File not found: {args.model_filepath}", file=sys.stderr) + sys.exit(1) + + model_classes, imports = analyze_model_file(args.model_filepath) + + if not model_classes: + print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr) + sys.exit(1) + + if args.class_name: + model_info = next((m for m in model_classes if m["name"] == args.class_name), None) + if not model_info: + available = [m["name"] for m in model_classes] + print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr) + sys.exit(1) + else: + model_info = model_classes[0] + if len(model_classes) > 1: + print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr) + print("Use --class-name to specify a different class", file=sys.stderr) + + include_optional = args.include + if "all" in include_optional: + include_optional = [flag for _, flag in OPTIONAL_TESTERS] + + generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports) + + if args.dry_run: + print(generated_code) + else: + output_path = args.output or get_test_output_path(args.model_filepath) + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + f.write(generated_code) + + print(f"Generated test file: {output_path}") + print(f"Model class: {model_info['name']}") + print(f"Detected attributes: {list(model_info['attributes'].keys())}") + + +if __name__ == "__main__": + main() diff --git a/utils/modular_auto_docstring.py b/utils/modular_auto_docstring.py new file mode 100644 index 000000000000..fc4a82f98ea1 --- /dev/null +++ b/utils/modular_auto_docstring.py @@ -0,0 +1,352 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Auto Docstring Generator for Modular Pipeline Blocks + +This script scans Python files for classes that have `# auto_docstring` comment above them +and inserts/updates the docstring from the class's `doc` property. 
+ +Run from the root of the repo: + python utils/modular_auto_docstring.py [path] [--fix_and_overwrite] + +Examples: + # Check for auto_docstring markers (will error if found without proper docstring) + python utils/modular_auto_docstring.py + + # Check specific directory + python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/ + + # Fix and overwrite the docstrings + python utils/modular_auto_docstring.py --fix_and_overwrite + +Usage in code: + # auto_docstring + class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + # docstring will be automatically inserted here + + @property + def doc(self): + return "Your docstring content..." +""" + +import argparse +import ast +import glob +import importlib +import os +import re +import subprocess +import sys + + +# All paths are set with the intent you should run this script from the root of the repo +DIFFUSERS_PATH = "src/diffusers" +REPO_PATH = "." + +# Pattern to match the auto_docstring comment +AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$") + + +def setup_diffusers_import(): + """Setup import path to use the local diffusers module.""" + src_path = os.path.join(REPO_PATH, "src") + if src_path not in sys.path: + sys.path.insert(0, src_path) + + +def get_module_from_filepath(filepath: str) -> str: + """Convert a filepath to a module name.""" + filepath = os.path.normpath(filepath) + + if filepath.startswith("src" + os.sep): + filepath = filepath[4:] + + if filepath.endswith(".py"): + filepath = filepath[:-3] + + module_name = filepath.replace(os.sep, ".") + return module_name + + +def load_module(filepath: str): + """Load a module from filepath.""" + setup_diffusers_import() + module_name = get_module_from_filepath(filepath) + + try: + module = importlib.import_module(module_name) + return module + except Exception as e: + print(f"Warning: Could not import module {module_name}: {e}") + return None + + +def get_doc_from_class(module, class_name: str) -> str: + """Get the doc property from an instantiated class.""" + if module is None: + return None + + cls = getattr(module, class_name, None) + if cls is None: + return None + + try: + instance = cls() + if hasattr(instance, "doc"): + return instance.doc + except Exception as e: + print(f"Warning: Could not instantiate {class_name}: {e}") + + return None + + +def find_auto_docstring_classes(filepath: str) -> list: + """ + Find all classes in a file that have # auto_docstring comment above them. 
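+ The marker must be on its own line above the class definition; blank lines and other comments between the marker and the `class` statement are skipped.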
+
+    Returns a list of (class_name, class_line_number, has_existing_docstring, docstring_end_line) tuples.
+    """
+    with open(filepath, "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+
+    # Parse AST to find class locations and their docstrings
+    content = "".join(lines)
+    try:
+        tree = ast.parse(content)
+    except SyntaxError as e:
+        print(f"Syntax error in {filepath}: {e}")
+        return []
+
+    # Build a map of class_name -> (class_line, has_docstring, docstring_end_line)
+    class_info = {}
+    for node in ast.walk(tree):
+        if isinstance(node, ast.ClassDef):
+            has_docstring = False
+            docstring_end_line = node.lineno  # default to class line
+
+            if node.body and isinstance(node.body[0], ast.Expr):
+                first_stmt = node.body[0]
+                if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str):
+                    has_docstring = True
+                    docstring_end_line = first_stmt.end_lineno or first_stmt.lineno
+
+            class_info[node.name] = (node.lineno, has_docstring, docstring_end_line)
+
+    # Now scan for # auto_docstring comments
+    classes_to_update = []
+
+    for i, line in enumerate(lines):
+        if AUTO_DOCSTRING_PATTERN.match(line):
+            # Found the marker; look for a class definition on the next non-empty, non-comment line
+            j = i + 1
+            while j < len(lines):
+                next_line = lines[j].strip()
+                if next_line and not next_line.startswith("#"):
+                    break
+                j += 1
+
+            if j < len(lines) and lines[j].strip().startswith("class "):
+                # Extract the class name
+                match = re.match(r"class\s+(\w+)", lines[j].strip())
+                if match:
+                    class_name = match.group(1)
+                    if class_name in class_info:
+                        class_line, has_docstring, docstring_end_line = class_info[class_name]
+                        classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line))
+
+    return classes_to_update
+
+
+def strip_class_name_line(doc: str, class_name: str) -> str:
+    """Remove the 'class ClassName' line from the doc if present."""
+    lines = doc.strip().split("\n")
+    if lines and lines[0].strip() == f"class {class_name}":
+        # Remove the class line and any blank line following it
+        lines = lines[1:]
+        while lines and not lines[0].strip():
+            lines = lines[1:]
+    return "\n".join(lines)
+
+
+def format_docstring(doc: str, indent: str = "    ") -> str:
+    """Format a doc string as a properly indented docstring."""
+    lines = doc.strip().split("\n")
+
+    if len(lines) == 1:
+        return f'{indent}"""{lines[0]}"""\n'
+    else:
+        result = [f'{indent}"""\n']
+        for line in lines:
+            if line.strip():
+                result.append(f"{indent}{line}\n")
+            else:
+                result.append("\n")
+        result.append(f'{indent}"""\n')
+        return "".join(result)
+
+
+def run_ruff_format(filepath: str):
+    """Run ruff check --fix, ruff format, and doc-builder style on a file to ensure consistent formatting."""
+    try:
+        # First run ruff check --fix to fix any linting issues (including line length)
+        subprocess.run(
+            ["ruff", "check", "--fix", filepath],
+            check=False,  # Don't fail if there are unfixable issues
+            capture_output=True,
+            text=True,
+        )
+        # Then run ruff format for code formatting
+        subprocess.run(
+            ["ruff", "format", filepath],
+            check=True,
+            capture_output=True,
+            text=True,
+        )
+        # Finally run doc-builder style for docstring formatting
+        subprocess.run(
+            ["doc-builder", "style", filepath, "--max_len", "119"],
+            check=False,  # Don't fail if doc-builder has issues
+            capture_output=True,
+            text=True,
+        )
+        print(f"Formatted {filepath}")
+    except subprocess.CalledProcessError as e:
+        print(f"Warning: formatting failed for {filepath}: {e.stderr}")
+    except FileNotFoundError as e:
+        print(f"Warning: tool not found ({e}). Skipping formatting.")
+    except Exception as e:
+        print(f"Warning: unexpected error formatting {filepath}: {e}")
+
+
+def get_existing_docstring(lines: list, class_line: int, docstring_end_line: int) -> str:
+    """Extract the existing docstring content from lines."""
+    # class_line is 1-indexed, so the docstring starts on the line after the class definition
+    # (0-indexed: class_line) and ends at docstring_end_line (1-indexed, inclusive)
+    docstring_lines = lines[class_line:docstring_end_line]
+    return "".join(docstring_lines)
+
+
+def process_file(filepath: str, overwrite: bool = False) -> list:
+    """
+    Process a file and find/insert docstrings for # auto_docstring marked classes.
+
+    Returns a list of classes that need updating.
+    """
+    classes_to_update = find_auto_docstring_classes(filepath)
+
+    if not classes_to_update:
+        return []
+
+    if not overwrite:
+        # Check mode: only verify that docstrings exist
+        # Content comparison is not reliable due to formatting differences
+        classes_needing_update = []
+        for class_name, class_line, has_docstring, docstring_end_line in classes_to_update:
+            if not has_docstring:
+                # No docstring exists, needs update
+                classes_needing_update.append((filepath, class_name, class_line))
+        return classes_needing_update
+
+    # Load the module to get doc properties
+    module = load_module(filepath)
+
+    with open(filepath, "r", encoding="utf-8", newline="\n") as f:
+        lines = f.readlines()
+
+    # Process in reverse order to maintain line numbers
+    updated = False
+    for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update):
+        doc = get_doc_from_class(module, class_name)
+
+        if doc is None:
+            print(f"Warning: Could not get doc for {class_name} in {filepath}")
+            continue
+
+        # Remove the "class ClassName" line since it's redundant in a docstring
+        doc = strip_class_name_line(doc, class_name)
+
+        # Format the new docstring with 4-space indent
+        new_docstring = format_docstring(doc, "    ")
+
+        if has_docstring:
+            # Replace existing docstring (line after class definition to docstring_end_line)
+            # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line
+            lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:]
+        else:
+            # Insert new docstring right after class definition line
+            # class_line is 1-indexed, so lines[class_line-1] is the class line
+            # Insert at position class_line (which is right after the class line)
+            lines = lines[:class_line] + [new_docstring] + lines[class_line:]
+
+        updated = True
+        print(f"Updated docstring for {class_name} in {filepath}")
+
+    if updated:
+        with open(filepath, "w", encoding="utf-8", newline="\n") as f:
+            f.writelines(lines)
+        # Run ruff format to ensure consistent line wrapping
+        run_ruff_format(filepath)
+
+    return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]
+
+
+def check_auto_docstrings(path: str = None, overwrite: bool = False):
+    """
+    Check all files for # auto_docstring markers and optionally fix them.
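+
+    In check mode (overwrite=False) a ValueError is raised when a marked class is missing its
+    docstring; with overwrite=True the docstrings are regenerated in place from each class's
+    doc property.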
+    """
+    if path is None:
+        path = DIFFUSERS_PATH
+
+    if os.path.isfile(path):
+        all_files = [path]
+    else:
+        all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True)
+
+    all_markers = []
+
+    for filepath in all_files:
+        markers = process_file(filepath, overwrite)
+        all_markers.extend(markers)
+
+    if not overwrite and len(all_markers) > 0:
+        message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers])
+        raise ValueError(
+            f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n"
+            f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them."
+        )
+
+    if overwrite and len(all_markers) > 0:
+        print(f"\nProcessed {len(all_markers)} docstring(s).")
+    elif not overwrite and len(all_markers) == 0:
+        print("All # auto_docstring markers have valid docstrings.")
+    elif len(all_markers) == 0:
+        print("No # auto_docstring markers found.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Check and fix # auto_docstring markers in modular pipeline blocks",
+    )
+    parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)")
+    parser.add_argument(
+        "--fix_and_overwrite",
+        action="store_true",
+        help="Whether to fix the docstrings by inserting them from the doc property.",
+    )
+
+    args = parser.parse_args()
+
+    check_auto_docstrings(args.path, args.fix_and_overwrite)